From d19b04464355c5b0b282aee38327a540d8d2163d Mon Sep 17 00:00:00 2001 From: juanmc2005 Date: Wed, 19 Apr 2023 17:41:41 +0200 Subject: [PATCH 01/14] New feature: streaming voice activity detection. Pipeline name changes --- src/diart/__init__.py | 10 +- src/diart/blocks/__init__.py | 5 +- src/diart/blocks/base.py | 92 ++++++++++++++ src/diart/blocks/config.py | 153 ----------------------- src/diart/blocks/diarization.py | 145 ++++++++++++++++++---- src/diart/blocks/vad.py | 208 ++++++++++++++++++++++++++++++++ src/diart/console/benchmark.py | 12 +- src/diart/console/client.py | 6 +- src/diart/console/serve.py | 19 +-- src/diart/console/stream.py | 19 +-- src/diart/console/tune.py | 26 +++- src/diart/inference.py | 86 +++++++------ src/diart/optim.py | 56 ++++----- src/diart/sinks.py | 47 +++++--- src/diart/sources.py | 2 +- src/diart/utils.py | 16 ++- 16 files changed, 605 insertions(+), 297 deletions(-) create mode 100644 src/diart/blocks/base.py delete mode 100644 src/diart/blocks/config.py create mode 100644 src/diart/blocks/vad.py diff --git a/src/diart/__init__.py b/src/diart/__init__.py index c9692638..e29287a0 100644 --- a/src/diart/__init__.py +++ b/src/diart/__init__.py @@ -1,6 +1,8 @@ from .blocks import ( - OnlineSpeakerDiarization, - BasePipeline, - PipelineConfig, - BasePipelineConfig, + SpeakerDiarization, + StreamingPipeline, + SpeakerDiarizationConfig, + StreamingConfig, + VoiceActivityDetection, + VoiceActivityDetectionConfig, ) diff --git a/src/diart/blocks/__init__.py b/src/diart/blocks/__init__.py index 59a6ef36..e6e8c479 100644 --- a/src/diart/blocks/__init__.py +++ b/src/diart/blocks/__init__.py @@ -13,6 +13,7 @@ OverlapAwareSpeakerEmbedding, ) from .segmentation import SpeakerSegmentation -from .diarization import OnlineSpeakerDiarization, BasePipeline -from .config import BasePipelineConfig, PipelineConfig +from .diarization import SpeakerDiarization, SpeakerDiarizationConfig +from .base import StreamingConfig, StreamingPipeline from .utils import Binarize, Resample, AdjustVolume +from .vad import VoiceActivityDetection, VoiceActivityDetectionConfig diff --git a/src/diart/blocks/base.py b/src/diart/blocks/base.py new file mode 100644 index 00000000..28f313eb --- /dev/null +++ b/src/diart/blocks/base.py @@ -0,0 +1,92 @@ +from typing import Any, Tuple, Sequence, Text +from dataclasses import dataclass + +import numpy as np +from pyannote.core import SlidingWindowFeature +from pyannote.metrics.base import BaseMetric + +from .. 
import utils +from ..audio import FilePath, AudioLoader + + +@dataclass +class HyperParameter: + name: Text + low: float + high: float + + @staticmethod + def from_name(name: Text) -> 'HyperParameter': + if name == "tau_active": + return TauActive + if name == "rho_update": + return RhoUpdate + if name == "delta_new": + return DeltaNew + raise ValueError(f"Hyper-parameter '{name}' not recognized") + + +TauActive = HyperParameter("tau_active", low=0, high=1) +RhoUpdate = HyperParameter("rho_update", low=0, high=1) +DeltaNew = HyperParameter("delta_new", low=0, high=2) + + +class StreamingConfig: + @property + def duration(self) -> float: + raise NotImplementedError + + @property + def step(self) -> float: + raise NotImplementedError + + @property + def latency(self) -> float: + raise NotImplementedError + + @property + def sample_rate(self) -> int: + raise NotImplementedError + + @staticmethod + def from_dict(data: Any) -> 'StreamingConfig': + raise NotImplementedError + + def get_file_padding(self, filepath: FilePath) -> Tuple[float, float]: + file_duration = AudioLoader(self.sample_rate, mono=True).get_duration(filepath) + right = utils.get_padding_right(self.latency, self.step) + left = utils.get_padding_left(file_duration + right, self.duration) + return left, right + + def optimal_block_size(self) -> int: + return int(np.rint(self.step * self.sample_rate)) + + +class StreamingPipeline: + @staticmethod + def get_config_class() -> type: + raise NotImplementedError + + @staticmethod + def suggest_metric() -> BaseMetric: + raise NotImplementedError + + @staticmethod + def hyper_parameters() -> Sequence[HyperParameter]: + raise NotImplementedError + + @property + def config(self) -> StreamingConfig: + raise NotImplementedError + + def reset(self): + raise NotImplementedError + + def set_timestamp_shift(self, shift: float): + raise NotImplementedError + + def __call__( + self, + waveforms: Sequence[SlidingWindowFeature] + ) -> Sequence[Tuple[Any, SlidingWindowFeature]]: + raise NotImplementedError diff --git a/src/diart/blocks/config.py b/src/diart/blocks/config.py deleted file mode 100644 index d8e2a656..00000000 --- a/src/diart/blocks/config.py +++ /dev/null @@ -1,153 +0,0 @@ -from typing import Any, Optional, Union, Tuple - -import numpy as np -import torch -from typing_extensions import Literal - -from .. import models as m -from .. 
import utils -from ..audio import FilePath, AudioLoader - - -class BasePipelineConfig: - @property - def duration(self) -> float: - raise NotImplementedError - - @property - def step(self) -> float: - raise NotImplementedError - - @property - def latency(self) -> float: - raise NotImplementedError - - @property - def sample_rate(self) -> int: - raise NotImplementedError - - @staticmethod - def from_dict(data: Any) -> 'BasePipelineConfig': - raise NotImplementedError - - def get_file_padding(self, filepath: FilePath) -> Tuple[float, float]: - file_duration = AudioLoader(self.sample_rate, mono=True).get_duration(filepath) - right = utils.get_padding_right(self.latency, self.step) - left = utils.get_padding_left(file_duration + right, self.duration) - return left, right - - def optimal_block_size(self) -> int: - return int(np.rint(self.step * self.sample_rate)) - - -class PipelineConfig(BasePipelineConfig): - def __init__( - self, - segmentation: Optional[m.SegmentationModel] = None, - embedding: Optional[m.EmbeddingModel] = None, - duration: Optional[float] = None, - step: float = 0.5, - latency: Optional[Union[float, Literal["max", "min"]]] = None, - tau_active: float = 0.6, - rho_update: float = 0.3, - delta_new: float = 1, - gamma: float = 3, - beta: float = 10, - max_speakers: int = 20, - device: Optional[torch.device] = None, - **kwargs, - ): - # Default segmentation model is pyannote/segmentation - self.segmentation = segmentation - if self.segmentation is None: - self.segmentation = m.SegmentationModel.from_pyannote("pyannote/segmentation") - - # Default duration is the one given by the segmentation model - self._duration = duration - - # Expected sample rate is given by the segmentation model - self._sample_rate: Optional[int] = None - - # Default embedding model is pyannote/embedding - self.embedding = embedding - if self.embedding is None: - self.embedding = m.EmbeddingModel.from_pyannote("pyannote/embedding") - - # Latency defaults to the step duration - self._step = step - self._latency = latency - if self._latency is None or self._latency == "min": - self._latency = self._step - elif self._latency == "max": - self._latency = self._duration - - self.tau_active = tau_active - self.rho_update = rho_update - self.delta_new = delta_new - self.gamma = gamma - self.beta = beta - self.max_speakers = max_speakers - - self.device = device - if self.device is None: - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - @staticmethod - def from_dict(data: Any) -> 'PipelineConfig': - # Check for explicit device, otherwise check for 'cpu' bool, otherwise pass None - device = utils.get(data, "device", None) - if device is None: - device = torch.device("cpu") if utils.get(data, "cpu", False) else None - - # Instantiate models - hf_token = utils.parse_hf_token_arg(utils.get(data, "hf_token", True)) - segmentation = utils.get(data, "segmentation", "pyannote/segmentation") - segmentation = m.SegmentationModel.from_pyannote(segmentation, hf_token) - embedding = utils.get(data, "embedding", "pyannote/embedding") - embedding = m.EmbeddingModel.from_pyannote(embedding, hf_token) - - # Hyper-parameters and their aliases - tau = utils.get(data, "tau_active", None) - if tau is None: - tau = utils.get(data, "tau", 0.6) - rho = utils.get(data, "rho_update", None) - if rho is None: - rho = utils.get(data, "rho", 0.3) - delta = utils.get(data, "delta_new", None) - if delta is None: - delta = utils.get(data, "delta", 1) - - return PipelineConfig( - segmentation=segmentation, - 
embedding=embedding, - duration=utils.get(data, "duration", None), - step=utils.get(data, "step", 0.5), - latency=utils.get(data, "latency", None), - tau_active=tau, - rho_update=rho, - delta_new=delta, - gamma=utils.get(data, "gamma", 3), - beta=utils.get(data, "beta", 10), - max_speakers=utils.get(data, "max_speakers", 20), - device=device, - ) - - @property - def duration(self) -> float: - if self._duration is None: - self._duration = self.segmentation.duration - return self._duration - - @property - def step(self) -> float: - return self._step - - @property - def latency(self) -> float: - return self._latency - - @property - def sample_rate(self) -> int: - if self._sample_rate is None: - self._sample_rate = self.segmentation.sample_rate - return self._sample_rate diff --git a/src/diart/blocks/diarization.py b/src/diart/blocks/diarization.py index 7f0e162c..f2a25119 100644 --- a/src/diart/blocks/diarization.py +++ b/src/diart/blocks/diarization.py @@ -1,42 +1,137 @@ -from typing import Optional, Tuple, Sequence +from typing import Optional, Tuple, Sequence, Union, Any import numpy as np import torch from pyannote.core import Annotation, SlidingWindowFeature, SlidingWindow, Segment +from pyannote.metrics.base import BaseMetric +from pyannote.metrics.diarization import DiarizationErrorRate +from typing_extensions import Literal from .aggregation import DelayedAggregation +from . import base from .clustering import OnlineSpeakerClustering from .embedding import OverlapAwareSpeakerEmbedding from .segmentation import SpeakerSegmentation from .utils import Binarize -from .config import BasePipelineConfig, PipelineConfig +from .. import models as m +from .. import utils -class BasePipeline: +class SpeakerDiarizationConfig(base.StreamingConfig): + def __init__( + self, + segmentation: Optional[m.SegmentationModel] = None, + embedding: Optional[m.EmbeddingModel] = None, + duration: Optional[float] = None, + step: float = 0.5, + latency: Optional[Union[float, Literal["max", "min"]]] = None, + tau_active: float = 0.6, + rho_update: float = 0.3, + delta_new: float = 1, + gamma: float = 3, + beta: float = 10, + max_speakers: int = 20, + device: Optional[torch.device] = None, + **kwargs, + ): + # Default segmentation model is pyannote/segmentation + self.segmentation = segmentation + if self.segmentation is None: + self.segmentation = m.SegmentationModel.from_pyannote("pyannote/segmentation") + + self._duration = duration + self._sample_rate: Optional[int] = None + + # Default embedding model is pyannote/embedding + self.embedding = embedding + if self.embedding is None: + self.embedding = m.EmbeddingModel.from_pyannote("pyannote/embedding") + + # Latency defaults to the step duration + self._step = step + self._latency = latency + if self._latency is None or self._latency == "min": + self._latency = self._step + elif self._latency == "max": + self._latency = self._duration + + self.tau_active = tau_active + self.rho_update = rho_update + self.delta_new = delta_new + self.gamma = gamma + self.beta = beta + self.max_speakers = max_speakers + + self.device = device + if self.device is None: + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + @staticmethod - def get_config_class() -> type: - raise NotImplementedError + def from_dict(data: Any) -> 'SpeakerDiarizationConfig': + # Check for explicit device, otherwise check for 'cpu' bool, otherwise pass None + device = utils.get(data, "device", None) + if device is None: + device = torch.device("cpu") if utils.get(data, "cpu", 
False) else None + + # Instantiate models + hf_token = utils.parse_hf_token_arg(utils.get(data, "hf_token", True)) + segmentation = utils.get(data, "segmentation", "pyannote/segmentation") + segmentation = m.SegmentationModel.from_pyannote(segmentation, hf_token) + embedding = utils.get(data, "embedding", "pyannote/embedding") + embedding = m.EmbeddingModel.from_pyannote(embedding, hf_token) + + # Hyper-parameters and their aliases + tau = utils.get(data, "tau_active", None) + if tau is None: + tau = utils.get(data, "tau", 0.6) + rho = utils.get(data, "rho_update", None) + if rho is None: + rho = utils.get(data, "rho", 0.3) + delta = utils.get(data, "delta_new", None) + if delta is None: + delta = utils.get(data, "delta", 1) + + return SpeakerDiarizationConfig( + segmentation=segmentation, + embedding=embedding, + duration=utils.get(data, "duration", None), + step=utils.get(data, "step", 0.5), + latency=utils.get(data, "latency", None), + tau_active=tau, + rho_update=rho, + delta_new=delta, + gamma=utils.get(data, "gamma", 3), + beta=utils.get(data, "beta", 10), + max_speakers=utils.get(data, "max_speakers", 20), + device=device, + ) @property - def config(self) -> BasePipelineConfig: - raise NotImplementedError + def duration(self) -> float: + # Default duration is the one given by the segmentation model + if self._duration is None: + self._duration = self.segmentation.duration + return self._duration - def reset(self): - raise NotImplementedError + @property + def step(self) -> float: + return self._step - def set_timestamp_shift(self, shift: float): - raise NotImplementedError + @property + def latency(self) -> float: + return self._latency - def __call__( - self, - waveforms: Sequence[SlidingWindowFeature] - ) -> Sequence[Tuple[Annotation, SlidingWindowFeature]]: - raise NotImplementedError + @property + def sample_rate(self) -> int: + # Expected sample rate is given by the segmentation model + if self._sample_rate is None: + self._sample_rate = self.segmentation.sample_rate + return self._sample_rate -class OnlineSpeakerDiarization(BasePipeline): - def __init__(self, config: Optional[PipelineConfig] = None): - self._config = PipelineConfig() if config is None else config +class SpeakerDiarization(base.StreamingPipeline): + def __init__(self, config: Optional[SpeakerDiarizationConfig] = None): + self._config = SpeakerDiarizationConfig() if config is None else config msg = f"Latency should be in the range [{self._config.step}, {self._config.duration}]" assert self._config.step <= self._config.latency <= self._config.duration, msg @@ -67,10 +162,18 @@ def __init__(self, config: Optional[PipelineConfig] = None): @staticmethod def get_config_class() -> type: - return PipelineConfig + return SpeakerDiarizationConfig + + @staticmethod + def suggest_metric() -> BaseMetric: + return DiarizationErrorRate(collar=0, skip_overlap=False) + + @staticmethod + def hyper_parameters() -> Sequence[base.HyperParameter]: + return [base.TauActive, base.RhoUpdate, base.DeltaNew] @property - def config(self) -> PipelineConfig: + def config(self) -> SpeakerDiarizationConfig: return self._config def set_timestamp_shift(self, shift: float): diff --git a/src/diart/blocks/vad.py b/src/diart/blocks/vad.py new file mode 100644 index 00000000..def833b6 --- /dev/null +++ b/src/diart/blocks/vad.py @@ -0,0 +1,208 @@ +from typing import Any, Optional, Union, Sequence, Tuple + +import numpy as np +import torch +from pyannote.core import Annotation, Timeline, SlidingWindowFeature, SlidingWindow, Segment +from 
pyannote.metrics.base import BaseMetric +from pyannote.metrics.detection import DetectionErrorRate +from typing_extensions import Literal + +from .aggregation import DelayedAggregation +from . import base +from .segmentation import SpeakerSegmentation +from .utils import Binarize +from .. import models as m +from .. import utils + + +class VoiceActivityDetectionConfig(base.StreamingConfig): + def __init__( + self, + segmentation: Optional[m.SegmentationModel] = None, + duration: Optional[float] = None, + step: float = 0.5, + latency: Optional[Union[float, Literal["max", "min"]]] = None, + tau_active: float = 0.6, + device: Optional[torch.device] = None, + **kwargs, + ): + # Default segmentation model is pyannote/segmentation + self.segmentation = segmentation + if self.segmentation is None: + self.segmentation = m.SegmentationModel.from_pyannote("pyannote/segmentation") + + self._duration = duration + self._step = step + self._sample_rate: Optional[int] = None + + # Latency defaults to the step duration + self._latency = latency + if self._latency is None or self._latency == "min": + self._latency = self._step + elif self._latency == "max": + self._latency = self._duration + + self.tau_active = tau_active + self.device = device + if self.device is None: + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + @property + def duration(self) -> float: + # Default duration is the one given by the segmentation model + if self._duration is None: + self._duration = self.segmentation.duration + return self._duration + + @property + def step(self) -> float: + return self._step + + @property + def latency(self) -> float: + return self._latency + + @property + def sample_rate(self) -> int: + # Expected sample rate is given by the segmentation model + if self._sample_rate is None: + self._sample_rate = self.segmentation.sample_rate + return self._sample_rate + + @staticmethod + def from_dict(data: Any) -> 'VoiceActivityDetectionConfig': + # Check for explicit device, otherwise check for 'cpu' bool, otherwise pass None + device = utils.get(data, "device", None) + if device is None: + device = torch.device("cpu") if utils.get(data, "cpu", False) else None + + # Instantiate segmentation model + hf_token = utils.parse_hf_token_arg(utils.get(data, "hf_token", True)) + segmentation = utils.get(data, "segmentation", "pyannote/segmentation") + segmentation = m.SegmentationModel.from_pyannote(segmentation, hf_token) + + # Tau active and its alias + tau = utils.get(data, "tau_active", None) + if tau is None: + tau = utils.get(data, "tau", 0.6) + + return VoiceActivityDetectionConfig( + segmentation=segmentation, + duration=utils.get(data, "duration", None), + step=utils.get(data, "step", 0.5), + latency=utils.get(data, "latency", None), + tau_active=tau, + device=device, + ) + + +class VoiceActivityDetection(base.StreamingPipeline): + def __init__(self, config: Optional[VoiceActivityDetectionConfig] = None): + self._config = VoiceActivityDetectionConfig() if config is None else config + + msg = f"Latency should be in the range [{self._config.step}, {self._config.duration}]" + assert self._config.step <= self._config.latency <= self._config.duration, msg + + self.segmentation = SpeakerSegmentation(self._config.segmentation, self._config.device) + self.pred_aggregation = DelayedAggregation( + self._config.step, + self._config.latency, + strategy="hamming", + cropping_mode="loose", + ) + self.audio_aggregation = DelayedAggregation( + self._config.step, + self._config.latency, + 
strategy="first", + cropping_mode="center", + ) + self.binarize = Binarize(self._config.tau_active) + + # Internal state, handle with care + self.timestamp_shift = 0 + self.chunk_buffer, self.pred_buffer = [], [] + + @staticmethod + def get_config_class() -> type: + return VoiceActivityDetectionConfig + + @staticmethod + def suggest_metric() -> BaseMetric: + return DetectionErrorRate(collar=0, skip_overlap=False) + + @staticmethod + def hyper_parameters() -> Sequence[base.HyperParameter]: + return [base.TauActive] + + @property + def config(self) -> base.StreamingConfig: + return self._config + + def reset(self): + self.set_timestamp_shift(0) + self.chunk_buffer, self.pred_buffer = [], [] + + def set_timestamp_shift(self, shift: float): + self.timestamp_shift = shift + + def __call__( + self, + waveforms: Sequence[SlidingWindowFeature], + ) -> Sequence[Tuple[Annotation, SlidingWindowFeature]]: + batch_size = len(waveforms) + msg = "Pipeline expected at least 1 input" + assert batch_size >= 1, msg + + # Create batch from chunk sequence, shape (batch, samples, channels) + batch = torch.stack([torch.from_numpy(w.data) for w in waveforms]) + + expected_num_samples = int(np.rint(self.config.duration * self.config.sample_rate)) + msg = f"Expected {expected_num_samples} samples per chunk, but got {batch.shape[1]}" + assert batch.shape[1] == expected_num_samples, msg + + # Extract segmentation + segmentations = self.segmentation(batch) # shape (batch, frames, speakers) + voice_detection = torch.max(segmentations, dim=-1, keepdim=True)[0] # shape (batch, frames, 1) + + seg_resolution = waveforms[0].extent.duration / segmentations.shape[1] + + outputs = [] + for wav, vad in zip(waveforms, voice_detection): + # Add timestamps to segmentation + sw = SlidingWindow( + start=wav.extent.start, + duration=seg_resolution, + step=seg_resolution, + ) + vad = SlidingWindowFeature(vad.cpu().numpy(), sw) + + # Update sliding buffer + self.chunk_buffer.append(wav) + self.pred_buffer.append(vad) + + # Aggregate buffer outputs for this time step + agg_waveform = self.audio_aggregation(self.chunk_buffer) + agg_prediction = self.pred_aggregation(self.pred_buffer) + agg_prediction = self.binarize(agg_prediction).get_timeline(copy=False) + + # Shift prediction timestamps if required + if self.timestamp_shift != 0: + shifted_agg_prediction = Timeline(uri=agg_prediction.uri) + for segment in agg_prediction: + new_segment = Segment( + segment.start + self.timestamp_shift, + segment.end + self.timestamp_shift, + ) + shifted_agg_prediction.add(new_segment) + agg_prediction = shifted_agg_prediction + + # Convert timeline into annotation with single speaker "speech" + agg_prediction = agg_prediction.to_annotation(utils.repeat_label("speech")) + outputs.append((agg_prediction, agg_waveform)) + + # Make place for new chunks in buffer if required + if len(self.chunk_buffer) == self.pred_aggregation.num_overlapping_windows: + self.chunk_buffer = self.chunk_buffer[1:] + self.pred_buffer = self.pred_buffer[1:] + + return outputs diff --git a/src/diart/console/benchmark.py b/src/diart/console/benchmark.py index b6a3f9ff..27d524c5 100644 --- a/src/diart/console/benchmark.py +++ b/src/diart/console/benchmark.py @@ -1,15 +1,17 @@ import argparse from pathlib import Path -import diart.argdoc as argdoc import pandas as pd -from diart.blocks import OnlineSpeakerDiarization, PipelineConfig +from diart import argdoc +from diart import utils from diart.inference import Benchmark, Parallelize def run(): parser = argparse.ArgumentParser() 
parser.add_argument("root", type=Path, help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)") + parser.add_argument("--pipeline", default="SpeakerDiarization", type=str, + help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'") parser.add_argument("--segmentation", default="pyannote/segmentation", type=str, help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation") parser.add_argument("--embedding", default="pyannote/embedding", type=str, @@ -34,6 +36,8 @@ def run(): help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)") args = parser.parse_args() + pipeline_class = utils.get_pipeline_class(args.pipeline) + benchmark = Benchmark( args.root, args.reference, @@ -43,11 +47,11 @@ def run(): batch_size=args.batch_size, ) - config = PipelineConfig.from_dict(vars(args)) + config = pipeline_class.get_config_class().from_dict(vars(args)) if args.num_workers > 0: benchmark = Parallelize(benchmark, args.num_workers) - report = benchmark(OnlineSpeakerDiarization, config) + report = benchmark(pipeline_class, config) if args.output is not None and isinstance(report, pd.DataFrame): report.to_csv(args.output / "benchmark_report.csv") diff --git a/src/diart/console/client.py b/src/diart/console/client.py index 084dbc13..db4915fa 100644 --- a/src/diart/console/client.py +++ b/src/diart/console/client.py @@ -3,11 +3,11 @@ from threading import Thread from typing import Text, Optional -import diart.argdoc as argdoc -import diart.sources as src -import diart.utils as utils import numpy as np import rx.operators as ops +from diart import argdoc +from diart import sources as src +from diart import utils from websocket import WebSocket diff --git a/src/diart/console/serve.py b/src/diart/console/serve.py index 2f632d57..46bb9328 100644 --- a/src/diart/console/serve.py +++ b/src/diart/console/serve.py @@ -1,10 +1,10 @@ import argparse from pathlib import Path -import diart.argdoc as argdoc -import diart.sources as src -from diart.blocks import OnlineSpeakerDiarization, PipelineConfig -from diart.inference import RealTimeInference +from diart import argdoc +from diart import sources as src +from diart import utils +from diart.inference import StreamingInference from diart.sinks import RTTMWriter @@ -12,6 +12,8 @@ def run(): parser = argparse.ArgumentParser() parser.add_argument("--host", default="0.0.0.0", type=str, help="Server host") parser.add_argument("--port", default=7007, type=int, help="Server port") + parser.add_argument("--pipeline", default="SpeakerDiarization", type=str, + help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'") parser.add_argument("--segmentation", default="pyannote/segmentation", type=str, help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation") parser.add_argument("--embedding", default="pyannote/embedding", type=str, @@ -31,15 +33,16 @@ def run(): help=f"{argdoc.HF_TOKEN}. 
Defaults to 'true' (required by pyannote)") args = parser.parse_args() - # Define online speaker diarization pipeline - config = PipelineConfig.from_dict(vars(args)) - pipeline = OnlineSpeakerDiarization(config) + # Resolve pipeline + pipeline_class = utils.get_pipeline_class(args.pipeline) + config = pipeline_class.get_config_class().from_dict(vars(args)) + pipeline = pipeline_class(config) # Create websocket audio source audio_source = src.WebSocketAudioSource(config.sample_rate, args.host, args.port) # Run online inference - inference = RealTimeInference( + inference = StreamingInference( pipeline, audio_source, batch_size=1, diff --git a/src/diart/console/stream.py b/src/diart/console/stream.py index d7218f07..e0c670c5 100644 --- a/src/diart/console/stream.py +++ b/src/diart/console/stream.py @@ -1,16 +1,18 @@ import argparse from pathlib import Path -import diart.argdoc as argdoc -import diart.sources as src -from diart.blocks import OnlineSpeakerDiarization, PipelineConfig -from diart.inference import RealTimeInference +from diart import argdoc +from diart import sources as src +from diart import utils +from diart.inference import StreamingInference from diart.sinks import RTTMWriter def run(): parser = argparse.ArgumentParser() parser.add_argument("source", type=str, help="Path to an audio file | 'microphone' | 'microphone:'") + parser.add_argument("--pipeline", default="SpeakerDiarization", type=str, + help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'") parser.add_argument("--segmentation", default="pyannote/segmentation", type=str, help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation") parser.add_argument("--embedding", default="pyannote/embedding", type=str, @@ -32,9 +34,10 @@ def run(): help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)") args = parser.parse_args() - # Define online speaker diarization pipeline - config = PipelineConfig.from_dict(vars(args)) - pipeline = OnlineSpeakerDiarization(config) + # Resolve pipeline + pipeline_class = utils.get_pipeline_class(args.pipeline) + config = pipeline_class.get_config_class().from_dict(vars(args)) + pipeline = pipeline_class(config) # Manage audio source block_size = config.optimal_block_size() @@ -51,7 +54,7 @@ def run(): audio_source = src.MicrophoneAudioSource(config.sample_rate, block_size, device) # Run online inference - inference = RealTimeInference( + inference = StreamingInference( pipeline, audio_source, batch_size=1, diff --git a/src/diart/console/tune.py b/src/diart/console/tune.py index 4ad8852a..a1f1b63a 100644 --- a/src/diart/console/tune.py +++ b/src/diart/console/tune.py @@ -1,10 +1,11 @@ import argparse from pathlib import Path -import diart.argdoc as argdoc import optuna -from diart.blocks import PipelineConfig, OnlineSpeakerDiarization -from diart.optim import Optimizer, HyperParameter +from diart import argdoc +from diart import utils +from diart.blocks.base import HyperParameter +from diart.optim import Optimizer from optuna.samplers import TPESampler @@ -13,6 +14,8 @@ def run(): parser.add_argument("root", type=str, help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)") parser.add_argument("--reference", required=True, type=str, help="Directory with RTTM files CONVERSATION.rttm. Names must match audio files") + parser.add_argument("--pipeline", default="SpeakerDiarization", type=str, + help="Class of the pipeline to optimize. 
Defaults to 'SpeakerDiarization'") parser.add_argument("--segmentation", default="pyannote/segmentation", type=str, help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation") parser.add_argument("--embedding", default="pyannote/embedding", type=str, @@ -38,17 +41,28 @@ def run(): help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)") args = parser.parse_args() + # Retrieve pipeline class + pipeline_class = utils.get_pipeline_class(args.pipeline) + # Create the base configuration for each trial - base_config = PipelineConfig.from_dict(vars(args)) + base_config = pipeline_class.get_config_class().from_dict(vars(args)) # Create hyper-parameters to optimize + possible_hparams = pipeline_class.hyper_parameters() hparams = [HyperParameter.from_name(name) for name in args.hparams] + hparams = [hp for hp in hparams if hp in possible_hparams] + if not hparams: + print( + f"No hyper-parameters to optimize. " + f"Make sure to select one of: {', '.join([hp.name for hp in possible_hparams])}" + ) + exit(1) # Use a custom storage if given if args.output is not None: msg = "Both `output` and `storage` were set, but only one was expected" assert args.storage is None, msg - args.output = Path(args.output) + args.output = Path(args.output).expanduser() args.output.mkdir(parents=True, exist_ok=True) study_or_path = args.output elif args.storage is not None: @@ -60,11 +74,11 @@ def run(): # Run optimization Optimizer( + pipeline_class=pipeline_class, speech_path=args.root, reference_path=args.reference, study_or_path=study_or_path, batch_size=args.batch_size, - pipeline_class=OnlineSpeakerDiarization, hparams=hparams, base_config=base_config, )(num_iter=args.num_iter, show_progress=True) diff --git a/src/diart/inference.py b/src/diart/inference.py index f4b65f5f..6afda89e 100644 --- a/src/diart/inference.py +++ b/src/diart/inference.py @@ -4,32 +4,33 @@ from traceback import print_exc from typing import Union, Text, Optional, Callable, Tuple, List -import diart.operators as dops -import diart.sources as src import numpy as np import pandas as pd import rx import rx.operators as ops import torch -from diart import utils -from diart.blocks import BasePipeline, Resample, BasePipelineConfig -from diart.progress import ProgressBar, RichProgressBar, TQDMProgressBar -from diart.sinks import DiarizationPredictionAccumulator, RealTimePlot, WindowClosedException from pyannote.core import Annotation, SlidingWindowFeature from pyannote.database.util import load_rttm -from pyannote.metrics.diarization import DiarizationErrorRate +from pyannote.metrics.base import BaseMetric from rx.core import Observer from tqdm import tqdm +from . import blocks +from . import operators as dops +from . import sources as src +from . import utils +from .progress import ProgressBar, RichProgressBar, TQDMProgressBar +from .sinks import PredictionAccumulator, StreamingPlot, WindowClosedException -class RealTimeInference: + +class StreamingInference: """Performs inference in real time given a pipeline and an audio source. Streams an audio source to an online speaker diarization pipeline. It allows users to attach a chain of operations in the form of hooks. Parameters ---------- - pipeline: BasePipeline + pipeline: StreamingPipeline Configured speaker diarization pipeline. source: AudioSource Audio source to be read and streamed. 
@@ -52,7 +53,7 @@ class RealTimeInference: """ def __init__( self, - pipeline: BasePipeline, + pipeline: blocks.StreamingPipeline, source: src.AudioSource, batch_size: int = 1, do_profile: bool = True, @@ -66,7 +67,7 @@ def __init__( self.do_profile = do_profile self.do_plot = do_plot self.show_progress = show_progress - self.accumulator = DiarizationPredictionAccumulator(self.source.uri) + self.accumulator = PredictionAccumulator(self.source.uri) self.unit = "chunk" if self.batch_size == 1 else "batch" self._observers = [] @@ -102,7 +103,7 @@ def __init__( f"but pipeline's is {sample_rate}. Will resample." logging.warning(msg) self.stream = self.stream.pipe( - ops.map(Resample(self.source.sample_rate, sample_rate)) + ops.map(blocks.Resample(self.source.sample_rate, sample_rate)) ) # Add rx operators to manage the inputs and outputs of the pipeline @@ -202,7 +203,7 @@ def __call__(self) -> Annotation: latency=config.latency, sample_rate=config.sample_rate, ), - ops.do(RealTimePlot(config.duration, config.latency)), + ops.do(StreamingPlot(config.duration, config.latency)), ) observable.subscribe( on_error=self._handle_error, @@ -288,7 +289,7 @@ def get_file_paths(self) -> List[Path]: def run_single( self, - pipeline: BasePipeline, + pipeline: blocks.StreamingPipeline, filepath: Path, progress_bar: ProgressBar, ) -> Annotation: @@ -298,7 +299,7 @@ def run_single( Parameters ---------- - pipeline: BasePipeline + pipeline: StreamingPipeline Speaker diarization pipeline to run. filepath: Path Path to the target file. @@ -318,7 +319,7 @@ def run_single( pipeline.config.optimal_block_size(), ) pipeline.set_timestamp_shift(-padding[0]) - inference = RealTimeInference( + inference = StreamingInference( pipeline, source, self.batch_size, @@ -337,7 +338,11 @@ def run_single( return pred - def evaluate(self, predictions: List[Annotation]) -> Union[pd.DataFrame, List[Annotation]]: + def evaluate( + self, + predictions: List[Annotation], + metric: BaseMetric, + ) -> Union[pd.DataFrame, List[Annotation]]: """If a reference path was provided, compute the diarization error rate of a list of predictions. @@ -345,6 +350,8 @@ def evaluate(self, predictions: List[Annotation]) -> Union[pd.DataFrame, List[An ---------- predictions: List[Annotation] Predictions to evaluate. + metric: BaseMetric + Evaluation metric from pyannote.metrics. Returns ------- @@ -353,8 +360,7 @@ def evaluate(self, predictions: List[Annotation]) -> Union[pd.DataFrame, List[An reference path was given. Otherwise return the same predictions. """ if self.reference_path is not None: - metric = DiarizationErrorRate(collar=0, skip_overlap=False) - progress_bar = TQDMProgressBar("Computing DER", leave=False) + progress_bar = TQDMProgressBar(f"Computing {metric.name}", leave=False) progress_bar.create(total=len(predictions), unit="file") progress_bar.start() for hyp in predictions: @@ -368,18 +374,22 @@ def evaluate(self, predictions: List[Annotation]) -> Union[pd.DataFrame, List[An def __call__( self, pipeline_class: type, - config: BasePipelineConfig, + config: blocks.StreamingConfig, + metric: Optional[BaseMetric] = None, ) -> Union[pd.DataFrame, List[Annotation]]: """Run a given pipeline on a set of audio files. - Notice that the internal state of the pipeline is reset before benchmarking. + The internal state of the pipeline is reset before benchmarking. Parameters ---------- pipeline_class: class - Class from the BasePipeline hierarchy. + Class from the StreamingPipeline hierarchy. 
A pipeline from this class will be instantiated by each worker. - config: BasePipelineConfig - Diarization pipeline configuration. + config: StreamingConfig + Streaming pipeline configuration. + metric: Optional[BaseMetric] + Evaluation metric from pyannote.metrics. + Defaults to the pipeline's suggested metric (see `StreamingPipeline.suggest_metric()`) Returns ------- @@ -400,7 +410,8 @@ def __call__( progress = TQDMProgressBar(desc, leave=False, do_close=True) predictions.append(self.run_single(pipeline, filepath, progress)) - return self.evaluate(predictions) + metric = pipeline.suggest_metric() if metric is None else metric + return self.evaluate(predictions, metric) class Parallelize: @@ -426,20 +437,20 @@ def __init__( def run_single_job( self, pipeline_class: type, - config: BasePipelineConfig, + config: blocks.StreamingConfig, filepath: Path, description: Text, - ): + ) -> Annotation: """Build and run a pipeline on a single file. Configure execution to show progress alongside parallel runs. Parameters ---------- pipeline_class: class - Class from the BasePipeline hierarchy. + Class from the StreamingPipeline hierarchy. A pipeline from this class will be instantiated. - config: BasePipelineConfig - Diarization pipeline configuration. + config: StreamingConfig + Streaming pipeline configuration. filepath: Path Path to the target file. description: Text @@ -463,7 +474,8 @@ def run_single_job( def __call__( self, pipeline_class: type, - config: BasePipelineConfig, + config: blocks.StreamingConfig, + metric: Optional[BaseMetric] = None, ) -> Union[pd.DataFrame, List[Annotation]]: """Run a given pipeline on a set of audio files in parallel. Each worker will build and run the pipeline on a different file. @@ -471,10 +483,13 @@ def __call__( Parameters ---------- pipeline_class: class - Class from the BasePipeline hierarchy. + Class from the StreamingPipeline hierarchy. A pipeline from this class will be instantiated by each worker. - config: BasePipelineConfig - Diarization pipeline configuration. + config: StreamingConfig + Streaming pipeline configuration. + metric: Optional[BaseMetric] + Evaluation metric from pyannote.metrics. + Defaults to the pipeline's suggested metric (see `StreamingPipeline.suggest_metric()`) Returns ------- @@ -512,4 +527,5 @@ def __call__( predictions = [job.get() for job in jobs] # Evaluate results - return self.benchmark.evaluate(predictions) + metric = pipeline_class.suggest_metric() if metric is None else metric + return self.benchmark.evaluate(predictions, metric) diff --git a/src/diart/optim.py b/src/diart/optim.py index 05800a05..f7a96a6e 100644 --- a/src/diart/optim.py +++ b/src/diart/optim.py @@ -1,51 +1,32 @@ from collections import OrderedDict -from dataclasses import dataclass from pathlib import Path from typing import Sequence, Text, Optional, Union from optuna import TrialPruned, Study, create_study from optuna.samplers import TPESampler from optuna.trial import Trial, FrozenTrial +from pyannote.metrics.base import BaseMetric from tqdm import trange, tqdm +from typing_extensions import Literal +from . 
import blocks from .audio import FilePath -from .blocks import BasePipelineConfig, PipelineConfig, OnlineSpeakerDiarization from .inference import Benchmark -@dataclass -class HyperParameter: - name: Text - low: float - high: float - - @staticmethod - def from_name(name: Text) -> 'HyperParameter': - if name == "tau_active": - return TauActive - if name == "rho_update": - return RhoUpdate - if name == "delta_new": - return DeltaNew - raise ValueError(f"Hyper-parameter '{name}' not recognized") - - -TauActive = HyperParameter("tau_active", low=0, high=1) -RhoUpdate = HyperParameter("rho_update", low=0, high=1) -DeltaNew = HyperParameter("delta_new", low=0, high=2) - - class Optimizer: def __init__( self, + pipeline_class: type, speech_path: Union[Text, Path], reference_path: Union[Text, Path], study_or_path: Union[FilePath, Study], batch_size: int = 32, - pipeline_class: type = OnlineSpeakerDiarization, - hparams: Optional[Sequence[HyperParameter]] = None, - base_config: Optional[BasePipelineConfig] = None, + hparams: Optional[Sequence[blocks.base.HyperParameter]] = None, + base_config: Optional[blocks.StreamingConfig] = None, do_kickstart_hparams: bool = True, + metric: Optional[BaseMetric] = None, + direction: Literal["minimize", "maximize"] = "minimize", ): self.pipeline_class = pipeline_class # FIXME can we run this benchmark in parallel? @@ -58,15 +39,17 @@ def __init__( batch_size=batch_size, ) + self.metric = metric + self.direction = direction self.base_config = base_config self.do_kickstart_hparams = do_kickstart_hparams if self.base_config is None: - self.base_config = PipelineConfig() + self.base_config = self.pipeline_class.get_config_class()() self.do_kickstart_hparams = False self.hparams = hparams if self.hparams is None: - self.hparams = [TauActive, RhoUpdate, DeltaNew] + self.hparams = self.pipeline_class.hyper_parameters() # Make sure hyper-parameters exist in the configuration class given possible_hparams = vars(self.base_config) @@ -85,7 +68,7 @@ def __init__( storage="sqlite:///" + str(study_or_path / f"{study_or_path.stem}.db"), sampler=TPESampler(), study_name=study_or_path.stem, - direction="minimize", + direction=self.direction, load_if_exists=True, ) else: @@ -105,7 +88,7 @@ def _callback(self, study: Study, trial: FrozenTrial): return self._progress.update(1) self._progress.set_description(f"Trial {trial.number + 1}") - values = {"best_der": study.best_value} + values = {"best_perf": study.best_value} for name, value in study.best_params.items(): values[f"best_{name}"] = value self._progress.set_postfix(OrderedDict(values)) @@ -125,11 +108,16 @@ def objective(self, trial: Trial) -> float: # Instantiate the new configuration for the trial config = self.base_config.__class__(**trial_config) + # Determine the evaluation metric + metric = self.metric + if metric is None: + metric = self.pipeline_class.suggest_metric() + # Run pipeline over the dataset - report = self.benchmark(self.pipeline_class, config) + report = self.benchmark(self.pipeline_class, config, metric) - # Extract DER from report - return report.loc["TOTAL", "diarization error rate"]["%"] + # Extract target metric from report + return report.loc["TOTAL", metric.name]["%"] def __call__(self, num_iter: int, show_progress: bool = True): self._progress = None diff --git a/src/diart/sinks.py b/src/diart/sinks.py index cf480bed..63c170d0 100644 --- a/src/diart/sinks.py +++ b/src/diart/sinks.py @@ -8,12 +8,14 @@ from rx.core import Observer from typing_extensions import Literal +from . 
import utils + class WindowClosedException(Exception): pass -def _extract_annotation(value: Union[Tuple, Annotation]) -> Annotation: +def _extract_prediction(value: Union[Tuple, Annotation]) -> Annotation: if isinstance(value, tuple): return value[0] if isinstance(value, Annotation): @@ -43,10 +45,11 @@ def patch(self): annotation.support(self.patch_collar).write_rttm(file) def on_next(self, value: Union[Tuple, Annotation]): - annotation = _extract_annotation(value) - annotation.uri = self.uri + prediction = _extract_prediction(value) + # Write prediction in RTTM format + prediction.uri = self.uri with open(self.path, 'a') as file: - annotation.write_rttm(file) + prediction.write_rttm(file) def on_error(self, error: Exception): self.patch() @@ -55,30 +58,30 @@ def on_completed(self): self.patch() -class DiarizationPredictionAccumulator(Observer): +class PredictionAccumulator(Observer): def __init__(self, uri: Optional[Text] = None, patch_collar: float = 0.05): super().__init__() self.uri = uri self.patch_collar = patch_collar - self._annotation = None + self._prediction: Optional[Annotation] = None def patch(self): """Stitch same-speaker turns that are close to each other""" - if self._annotation is not None: - self._annotation = self._annotation.support(self.patch_collar) + if self._prediction is not None: + self._prediction = self._prediction.support(self.patch_collar) def get_prediction(self) -> Annotation: # Patch again in case this is called before on_completed self.patch() - return self._annotation + return self._prediction def on_next(self, value: Union[Tuple, Annotation]): - annotation = _extract_annotation(value) - annotation.uri = self.uri - if self._annotation is None: - self._annotation = annotation + prediction = _extract_prediction(value) + prediction.uri = self.uri + if self._prediction is None: + self._prediction = prediction else: - self._annotation.update(annotation) + self._prediction.update(prediction) def on_error(self, error: Exception): self.patch() @@ -87,7 +90,7 @@ def on_completed(self): self.patch() -class RealTimePlot(Observer): +class StreamingPlot(Observer): def __init__( self, duration: float, @@ -134,11 +137,15 @@ def get_plot_bounds(self, real_time: float) -> Segment: start_time = max(0., end_time - self.window_duration) return Segment(start_time, end_time) - def on_next(self, values: Tuple[Annotation, SlidingWindowFeature, float]): + def on_next( + self, + values: Tuple[Annotation, SlidingWindowFeature, float] + ): if self.window_closed: raise WindowClosedException prediction, waveform, real_time = values + # Initialize figure if first call if self.figure is None: self._init_figure() @@ -147,15 +154,21 @@ def on_next(self, values: Tuple[Annotation, SlidingWindowFeature, float]): # Set plot bounds notebook.crop = self.get_plot_bounds(real_time) - # Plot current values + # Align prediction and reference if possible if self.reference is not None: metric = DiarizationErrorRate() mapping = metric.optimal_mapping(self.reference, prediction) prediction.rename_labels(mapping=mapping, copy=False) + + # Plot prediction notebook.plot_annotation(prediction, self.axs[0]) self.axs[0].set_title("Output") + + # Plot waveform notebook.plot_feature(waveform, self.axs[1]) self.axs[1].set_title("Audio") + + # Plot reference if available if self.num_axs == 3: notebook.plot_annotation(self.reference, self.axs[2]) self.axs[2].set_title("Reference") diff --git a/src/diart/sources.py b/src/diart/sources.py index 0f5dedf7..b34d5cf3 100644 --- a/src/diart/sources.py +++ 
b/src/diart/sources.py
@@ -5,12 +5,12 @@
 import numpy as np
 import sounddevice as sd
 import torch
-from diart import utils
 from einops import rearrange
 from rx.subject import Subject
 from torchaudio.io import StreamReader
 from websocket_server import WebsocketServer
 
+from . import utils
 from .audio import FilePath, AudioLoader
 
 
diff --git a/src/diart/utils.py b/src/diart/utils.py
index e90861c7..e825ef29 100644
--- a/src/diart/utils.py
+++ b/src/diart/utils.py
@@ -4,9 +4,11 @@
 
 import matplotlib.pyplot as plt
 import numpy as np
-from diart.progress import ProgressBar
 from pyannote.core import Annotation, Segment, SlidingWindowFeature, notebook
 
+from .progress import ProgressBar
+from . import blocks
+
 
 class Chronometer:
     def __init__(self, unit: Text, progress_bar: Optional[ProgressBar] = None):
@@ -74,6 +76,18 @@ def get_padding_left(stream_duration: float, chunk_duration: float) -> float:
     return 0
 
 
+def repeat_label(label: Text):
+    while True:
+        yield label
+
+
+def get_pipeline_class(class_name: Text) -> type:
+    pipeline_class = getattr(blocks, class_name, None)
+    msg = f"Pipeline '{class_name}' doesn't exist"
+    assert pipeline_class is not None, msg
+    return pipeline_class
+
+
 def get_padding_right(latency: float, step: float) -> float:
     return latency - step
 

From 6caa4a4ab9b2e8c2ab7bc5049b1112f325a08d5b Mon Sep 17 00:00:00 2001
From: juanmc2005
Date: Wed, 19 Apr 2023 17:43:51 +0200
Subject: [PATCH 02/14] Update link in setup.cfg

---
 setup.cfg | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/setup.cfg b/setup.cfg
index 594c876e..e67e4426 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -2,11 +2,11 @@
 name=diart
 version=0.7.0
 author=Juan Manuel Coria
-description=Speaker diarization in real time
+description=Streaming speaker diarization in real-time
 long_description=file: README.md
 long_description_content_type=text/markdown
 keywords=speaker diarization, streaming, online, real time, rxpy
-url=https://github.com/juanmc2005/StreamingSpeakerDiarization
+url=https://github.com/juanmc2005/diart
 license=MIT
 classifiers=
     Development Status :: 4 - Beta

From 0993fe85411d09f3b0bb5db709b7b02a7a56f0be Mon Sep 17 00:00:00 2001
From: juanmc2005
Date: Wed, 19 Apr 2023 17:51:41 +0200
Subject: [PATCH 03/14] Update code snippets in README

---
 README.md | 42 +++++++++++++++++++++---------------------
 1 file changed, 21 insertions(+), 21 deletions(-)

diff --git a/README.md b/README.md
index ef533946..57ca293a 100644
--- a/README.md
+++ b/README.md
@@ -110,17 +110,17 @@ See `diart.stream -h` for more options.
 ### From python
 
-Use `RealTimeInference` to easily run a pipeline on an audio source and write the results to disk:
+Use `StreamingInference` to run a pipeline on an audio source and write the results to disk:
 
 ```python
-from diart import OnlineSpeakerDiarization
+from diart import SpeakerDiarization
 from diart.sources import MicrophoneAudioSource
-from diart.inference import RealTimeInference
+from diart.inference import StreamingInference
 from diart.sinks import RTTMWriter
 
-pipeline = OnlineSpeakerDiarization()
+pipeline = SpeakerDiarization()
 mic = MicrophoneAudioSource(pipeline.config.sample_rate)
-inference = RealTimeInference(pipeline, mic, do_plot=True)
+inference = StreamingInference(pipeline, mic, do_plot=True)
 inference.attach_observers(RTTMWriter(mic.uri, "/output/file.rttm"))
 prediction = inference()
 ```
 
@@ -129,13 +129,13 @@ For inference and evaluation on a dataset we recommend to use `Benchmark` (see n
 
 ## 🤖 Custom models
 
-Third-party models can be integrated seamlessly by subclassing `SegmentationModel` and `EmbeddingModel` (which are PyTorch `Module` subclasses):
+Third-party models can be integrated by subclassing `SegmentationModel` and `EmbeddingModel` (both PyTorch `nn.Module`):
 
 ```python
-from diart import OnlineSpeakerDiarization, PipelineConfig
+from diart import SpeakerDiarization, SpeakerDiarizationConfig
 from diart.models import EmbeddingModel, SegmentationModel
 from diart.sources import MicrophoneAudioSource
-from diart.inference import RealTimeInference
+from diart.inference import StreamingInference
 
 
 def model_loader():
@@ -168,19 +168,19 @@ class MyEmbeddingModel(EmbeddingModel):
         return self.model(waveform, weights)
 
 
-config = PipelineConfig(
+config = SpeakerDiarizationConfig(
     segmentation=MySegmentationModel(),
     embedding=MyEmbeddingModel()
 )
-pipeline = OnlineSpeakerDiarization(config)
+pipeline = SpeakerDiarization(config)
 mic = MicrophoneAudioSource(config.sample_rate)
-inference = RealTimeInference(pipeline, mic)
+inference = StreamingInference(pipeline, mic)
 prediction = inference()
 ```
 
 ## 📈 Tune hyper-parameters
 
-Diart implements a hyper-parameter optimizer based on [optuna](https://optuna.readthedocs.io/en/stable/index.html) that allows you to tune any pipeline to any dataset.
+Diart implements an optimizer based on [optuna](https://optuna.readthedocs.io/en/stable/index.html) that allows you to tune pipeline hyper-parameters to your needs.
 
 ### From the command line
 
@@ -281,7 +281,7 @@ diart.serve --host 0.0.0.0 --port 7007
 diart.client microphone --host --port 7007
 ```
 
-**Note:** please make sure that the client uses the same `step` and `sample_rate` than the server with `--step` and `-sr`.
+**Note:** make sure that the client uses the same `step` and `sample_rate` than the server with `--step` and `-sr`.
 
 See `-h` for more options.
 
 For customized solutions, a server can also be created in python using the `WebSocketAudioSource`:
 
 ```python
-from diart import OnlineSpeakerDiarization
+from diart import SpeakerDiarization
 from diart.sources import WebSocketAudioSource
-from diart.inference import RealTimeInference
+from diart.inference import StreamingInference
 
-pipeline = OnlineSpeakerDiarization()
+pipeline = SpeakerDiarization()
 source = WebSocketAudioSource(pipeline.config.sample_rate, "localhost", 7007)
-inference = RealTimeInference(pipeline, source)
+inference = StreamingInference(pipeline, source)
 inference.attach_hooks(lambda ann_wav: source.send(ann_wav[0].to_rttm()))
 prediction = inference()
 ```
@@ -354,14 +354,14 @@ or using the inference API:
 ```python
 from diart.inference import Benchmark, Parallelize
-from diart import OnlineSpeakerDiarization, PipelineConfig
+from diart import SpeakerDiarization, SpeakerDiarizationConfig
 from diart.models import SegmentationModel
 
 benchmark = Benchmark("/wav/dir", "/rttm/dir")
 
 name = "pyannote/segmentation@Interspeech2021"
 segmentation = SegmentationModel.from_pyannote(name)
 
-config = PipelineConfig(
+config = SpeakerDiarizationConfig(
     # Set the model used in the paper
     segmentation=segmentation,
     step=0.5,
@@ -370,12 +370,12 @@ config = PipelineConfig(
     rho_update=0.422,
     delta_new=1.517
 )
-benchmark(OnlineSpeakerDiarization, config)
+benchmark(SpeakerDiarization, config)
 
 # Run the same benchmark in parallel
 p_benchmark = Parallelize(benchmark, num_workers=4)
 if __name__ == "__main__":  # Needed for multiprocessing
-    p_benchmark(OnlineSpeakerDiarization, config)
+    p_benchmark(SpeakerDiarization, config)
 ```
 
 This pre-calculates model outputs in batches, so it runs a lot faster.

From 95d4fae66dea06e1cbb12ac591e5f323687cd02f Mon Sep 17 00:00:00 2001
From: juanmc2005
Date: Wed, 19 Apr 2023 21:18:36 +0200
Subject: [PATCH 04/14] Add minor README modifications

---
 README.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index 57ca293a..ae13059f 100644
--- a/README.md
+++ b/README.md
@@ -24,7 +24,7 @@
 |
- 🤖 Custom models
+ 🤖 Add your model
 |
@@ -127,7 +127,7 @@ prediction = inference()
 
 For inference and evaluation on a dataset we recommend to use `Benchmark` (see notes on [reproducibility](#reproducibility)).
 
-## 🤖 Custom models +## 🤖 Add your model Third-party models can be integrated by subclassing `SegmentationModel` and `EmbeddingModel` (both PyTorch `nn.Module`): From 569c68fa5648c9c940dae215ed557582f17a513f Mon Sep 17 00:00:00 2001 From: juanmc2005 Date: Mon, 24 Apr 2023 11:25:51 +0200 Subject: [PATCH 05/14] Rename base pipeline and config objects --- src/diart/__init__.py | 4 ++-- src/diart/blocks/__init__.py | 2 +- src/diart/blocks/base.py | 8 ++++---- src/diart/blocks/diarization.py | 4 ++-- src/diart/blocks/vad.py | 6 +++--- src/diart/inference.py | 10 +++++----- src/diart/optim.py | 2 +- 7 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/diart/__init__.py b/src/diart/__init__.py index e29287a0..4bd51327 100644 --- a/src/diart/__init__.py +++ b/src/diart/__init__.py @@ -1,8 +1,8 @@ from .blocks import ( SpeakerDiarization, - StreamingPipeline, + Pipeline, SpeakerDiarizationConfig, - StreamingConfig, + PipelineConfig, VoiceActivityDetection, VoiceActivityDetectionConfig, ) diff --git a/src/diart/blocks/__init__.py b/src/diart/blocks/__init__.py index e6e8c479..15cf81d9 100644 --- a/src/diart/blocks/__init__.py +++ b/src/diart/blocks/__init__.py @@ -14,6 +14,6 @@ ) from .segmentation import SpeakerSegmentation from .diarization import SpeakerDiarization, SpeakerDiarizationConfig -from .base import StreamingConfig, StreamingPipeline +from .base import PipelineConfig, Pipeline from .utils import Binarize, Resample, AdjustVolume from .vad import VoiceActivityDetection, VoiceActivityDetectionConfig diff --git a/src/diart/blocks/base.py b/src/diart/blocks/base.py index 28f313eb..11ef961d 100644 --- a/src/diart/blocks/base.py +++ b/src/diart/blocks/base.py @@ -31,7 +31,7 @@ def from_name(name: Text) -> 'HyperParameter': DeltaNew = HyperParameter("delta_new", low=0, high=2) -class StreamingConfig: +class PipelineConfig: @property def duration(self) -> float: raise NotImplementedError @@ -49,7 +49,7 @@ def sample_rate(self) -> int: raise NotImplementedError @staticmethod - def from_dict(data: Any) -> 'StreamingConfig': + def from_dict(data: Any) -> 'PipelineConfig': raise NotImplementedError def get_file_padding(self, filepath: FilePath) -> Tuple[float, float]: @@ -62,7 +62,7 @@ def optimal_block_size(self) -> int: return int(np.rint(self.step * self.sample_rate)) -class StreamingPipeline: +class Pipeline: @staticmethod def get_config_class() -> type: raise NotImplementedError @@ -76,7 +76,7 @@ def hyper_parameters() -> Sequence[HyperParameter]: raise NotImplementedError @property - def config(self) -> StreamingConfig: + def config(self) -> PipelineConfig: raise NotImplementedError def reset(self): diff --git a/src/diart/blocks/diarization.py b/src/diart/blocks/diarization.py index f2a25119..06658cfc 100644 --- a/src/diart/blocks/diarization.py +++ b/src/diart/blocks/diarization.py @@ -17,7 +17,7 @@ from .. 
import utils -class SpeakerDiarizationConfig(base.StreamingConfig): +class SpeakerDiarizationConfig(base.PipelineConfig): def __init__( self, segmentation: Optional[m.SegmentationModel] = None, @@ -129,7 +129,7 @@ def sample_rate(self) -> int: return self._sample_rate -class SpeakerDiarization(base.StreamingPipeline): +class SpeakerDiarization(base.Pipeline): def __init__(self, config: Optional[SpeakerDiarizationConfig] = None): self._config = SpeakerDiarizationConfig() if config is None else config diff --git a/src/diart/blocks/vad.py b/src/diart/blocks/vad.py index def833b6..e519a9cf 100644 --- a/src/diart/blocks/vad.py +++ b/src/diart/blocks/vad.py @@ -15,7 +15,7 @@ from .. import utils -class VoiceActivityDetectionConfig(base.StreamingConfig): +class VoiceActivityDetectionConfig(base.PipelineConfig): def __init__( self, segmentation: Optional[m.SegmentationModel] = None, @@ -96,7 +96,7 @@ def from_dict(data: Any) -> 'VoiceActivityDetectionConfig': ) -class VoiceActivityDetection(base.StreamingPipeline): +class VoiceActivityDetection(base.Pipeline): def __init__(self, config: Optional[VoiceActivityDetectionConfig] = None): self._config = VoiceActivityDetectionConfig() if config is None else config @@ -135,7 +135,7 @@ def hyper_parameters() -> Sequence[base.HyperParameter]: return [base.TauActive] @property - def config(self) -> base.StreamingConfig: + def config(self) -> base.PipelineConfig: return self._config def reset(self): diff --git a/src/diart/inference.py b/src/diart/inference.py index 6afda89e..f562fdd9 100644 --- a/src/diart/inference.py +++ b/src/diart/inference.py @@ -53,7 +53,7 @@ class StreamingInference: """ def __init__( self, - pipeline: blocks.StreamingPipeline, + pipeline: blocks.Pipeline, source: src.AudioSource, batch_size: int = 1, do_profile: bool = True, @@ -289,7 +289,7 @@ def get_file_paths(self) -> List[Path]: def run_single( self, - pipeline: blocks.StreamingPipeline, + pipeline: blocks.Pipeline, filepath: Path, progress_bar: ProgressBar, ) -> Annotation: @@ -374,7 +374,7 @@ def evaluate( def __call__( self, pipeline_class: type, - config: blocks.StreamingConfig, + config: blocks.PipelineConfig, metric: Optional[BaseMetric] = None, ) -> Union[pd.DataFrame, List[Annotation]]: """Run a given pipeline on a set of audio files. @@ -437,7 +437,7 @@ def __init__( def run_single_job( self, pipeline_class: type, - config: blocks.StreamingConfig, + config: blocks.PipelineConfig, filepath: Path, description: Text, ) -> Annotation: @@ -474,7 +474,7 @@ def run_single_job( def __call__( self, pipeline_class: type, - config: blocks.StreamingConfig, + config: blocks.PipelineConfig, metric: Optional[BaseMetric] = None, ) -> Union[pd.DataFrame, List[Annotation]]: """Run a given pipeline on a set of audio files in parallel. diff --git a/src/diart/optim.py b/src/diart/optim.py index f7a96a6e..86492627 100644 --- a/src/diart/optim.py +++ b/src/diart/optim.py @@ -23,7 +23,7 @@ def __init__( study_or_path: Union[FilePath, Study], batch_size: int = 32, hparams: Optional[Sequence[blocks.base.HyperParameter]] = None, - base_config: Optional[blocks.StreamingConfig] = None, + base_config: Optional[blocks.PipelineConfig] = None, do_kickstart_hparams: bool = True, metric: Optional[BaseMetric] = None, direction: Literal["minimize", "maximize"] = "minimize", From 061c49be41c47d38f6fdb65d50910a061d00b75c Mon Sep 17 00:00:00 2001 From: Jonny Saunders Date: Fri, 6 Oct 2023 03:24:07 -0700 Subject: [PATCH 06/14] [joss] Add Conda environment (#172) * Conda environment! 
* Update README with env instructions --------- Co-authored-by: Juan Coria --- README.md | 10 ++-------- environment.yml | 12 ++++++++++++ 2 files changed, 14 insertions(+), 8 deletions(-) create mode 100644 environment.yml diff --git a/README.md b/README.md index ae13059f..0ddc935f 100644 --- a/README.md +++ b/README.md @@ -64,17 +64,11 @@ 1) Create environment: ```shell -conda create -n diart python=3.8 +conda env create -f diart/environment.yml conda activate diart ``` -2) Install audio libraries: - -```shell -conda install portaudio pysoundfile ffmpeg -c conda-forge -``` - -3) Install diart: +2) Install the package: ```shell pip install diart ``` diff --git a/environment.yml b/environment.yml new file mode 100644 index 00000000..f62b4274 --- /dev/null +++ b/environment.yml @@ -0,0 +1,12 @@ +name: diart +channels: + - conda-forge + - defaults +dependencies: + - python=3.8 + - portaudio=19.6.* + - pysoundfile=0.12.* + - ffmpeg[version='<4.4'] + - pip + - pip: + - . \ No newline at end of file From fa4eecdf3934fa07f7b9c8dc0d3a880209c90029 Mon Sep 17 00:00:00 2001 From: Juan Coria Date: Wed, 11 Oct 2023 16:03:39 +0200 Subject: [PATCH 07/14] Replace NotImplementedError with ABC/abstractmethod (#179) Co-authored-by: sneakers-the-rat --- src/diart/blocks/aggregation.py | 6 +++-- src/diart/blocks/base.py | 41 ++++++++++++++++++++++----------- src/diart/blocks/diarization.py | 2 +- src/diart/features.py | 10 +++++--- src/diart/mapping.py | 12 ++++++---- src/diart/models.py | 15 ++++++++---- src/diart/progress.py | 27 ++++++++++++++-------- src/diart/sources.py | 9 +++++--- 8 files changed, 81 insertions(+), 41 deletions(-) diff --git a/src/diart/blocks/aggregation.py b/src/diart/blocks/aggregation.py index b6352a28..41b836fe 100644 --- a/src/diart/blocks/aggregation.py +++ b/src/diart/blocks/aggregation.py @@ -1,3 +1,4 @@ +from abc import ABC, abstractmethod from typing import Optional, List import numpy as np @@ -5,7 +6,7 @@ from typing_extensions import Literal -class AggregationStrategy: +class AggregationStrategy(ABC): """Abstract class representing a strategy to aggregate overlapping buffers Parameters @@ -58,8 +59,9 @@ def __call__(self, buffers: List[SlidingWindowFeature], focus: Segment) -> Slidi ) return SlidingWindowFeature(aggregation, resolution) + @abstractmethod def aggregate(self, buffers: List[SlidingWindowFeature], focus: Segment) -> np.ndarray: - raise NotImplementedError + pass class HammingWeightedAverageStrategy(AggregationStrategy): diff --git a/src/diart/blocks/base.py b/src/diart/blocks/base.py index 11ef961d..11d298b0 100644 --- a/src/diart/blocks/base.py +++ b/src/diart/blocks/base.py @@ -1,5 +1,6 @@ from typing import Any, Tuple, Sequence, Text from dataclasses import dataclass +from abc import ABC, abstractmethod import numpy as np from pyannote.core import SlidingWindowFeature @@ -31,26 +32,31 @@ def from_name(name: Text) -> 'HyperParameter': DeltaNew = HyperParameter("delta_new", low=0, high=2) -class PipelineConfig: +class PipelineConfig(ABC): @property + @abstractmethod def duration(self) -> float: - raise NotImplementedError + pass @property + @abstractmethod def step(self) -> float: - raise NotImplementedError + pass @property + @abstractmethod def latency(self) -> float: - raise NotImplementedError + pass @property + @abstractmethod def sample_rate(self) -> int: - raise NotImplementedError + pass @staticmethod + @abstractmethod def from_dict(data: Any) -> 'PipelineConfig': - raise NotImplementedError + pass def get_file_padding(self, filepath: FilePath) -> 
Tuple[float, float]: file_duration = AudioLoader(self.sample_rate, mono=True).get_duration(filepath) @@ -62,31 +68,38 @@ def optimal_block_size(self) -> int: return int(np.rint(self.step * self.sample_rate)) -class Pipeline: +class Pipeline(ABC): @staticmethod + @abstractmethod def get_config_class() -> type: - raise NotImplementedError + pass @staticmethod + @abstractmethod def suggest_metric() -> BaseMetric: - raise NotImplementedError + pass @staticmethod + @abstractmethod def hyper_parameters() -> Sequence[HyperParameter]: - raise NotImplementedError + pass @property + @abstractmethod def config(self) -> PipelineConfig: - raise NotImplementedError + pass + @abstractmethod def reset(self): - raise NotImplementedError + pass + @abstractmethod def set_timestamp_shift(self, shift: float): - raise NotImplementedError + pass + @abstractmethod def __call__( self, waveforms: Sequence[SlidingWindowFeature] ) -> Sequence[Tuple[Any, SlidingWindowFeature]]: - raise NotImplementedError + pass diff --git a/src/diart/blocks/diarization.py b/src/diart/blocks/diarization.py index 06658cfc..ec9a73f3 100644 --- a/src/diart/blocks/diarization.py +++ b/src/diart/blocks/diarization.py @@ -7,8 +7,8 @@ from pyannote.metrics.diarization import DiarizationErrorRate from typing_extensions import Literal -from .aggregation import DelayedAggregation from . import base +from .aggregation import DelayedAggregation from .clustering import OnlineSpeakerClustering from .embedding import OverlapAwareSpeakerEmbedding from .segmentation import SpeakerSegmentation diff --git a/src/diart/features.py b/src/diart/features.py index 2489027a..ffa83c8f 100644 --- a/src/diart/features.py +++ b/src/diart/features.py @@ -1,4 +1,5 @@ from typing import Union, Optional +from abc import ABC, abstractmethod import numpy as np import torch @@ -7,15 +8,18 @@ TemporalFeatures = Union[SlidingWindowFeature, np.ndarray, torch.Tensor] -class TemporalFeatureFormatterState: +class TemporalFeatureFormatterState(ABC): """ Represents the recorded type of a temporal feature formatter. Its job is to transform temporal features into tensors and recover the original format on other features. """ + + @abstractmethod def to_tensor(self, features: TemporalFeatures) -> torch.Tensor: - raise NotImplementedError + pass + @abstractmethod def to_internal_type(self, features: torch.Tensor) -> TemporalFeatures: """ Cast `features` to the representing type and remove batch dimension if required. 
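
The hunk above follows the pattern applied throughout this patch: methods that previously raised `NotImplementedError` become `@abstractmethod`s on an `ABC` base class, so a missing override is reported when the subclass is instantiated instead of when the method is first called. A minimal sketch of that pattern, assuming nothing beyond the standard library and numpy (class names are illustrative, not diart classes):

```python
from abc import ABC, abstractmethod

import numpy as np


class FormatterStateLike(ABC):  # hypothetical stand-in for a diart base class
    @abstractmethod
    def to_tensor(self, features):
        pass


class NumpyStateLike(FormatterStateLike):
    def to_tensor(self, features):
        # Toy implementation: a real formatter state would return a torch.Tensor
        return np.asarray(features)


state = NumpyStateLike()   # OK: all abstract methods are implemented
# FormatterStateLike()     # raises TypeError: can't instantiate abstract class
```
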
@@ -28,7 +32,7 @@ def to_internal_type(self, features: torch.Tensor) -> TemporalFeatures: ------- new_features: SlidingWindowFeature or numpy.ndarray or torch.Tensor, shape (batch, frames, dim) """ - raise NotImplementedError + pass class SlidingWindowFeatureFormatterState(TemporalFeatureFormatterState): diff --git a/src/diart/mapping.py b/src/diart/mapping.py index 3023da4d..847d0f78 100644 --- a/src/diart/mapping.py +++ b/src/diart/mapping.py @@ -1,13 +1,14 @@ from __future__ import annotations from typing import Callable, Iterable, List, Optional, Text, Tuple, Union, Dict +from abc import ABC, abstractmethod import numpy as np from pyannote.core.utils.distance import cdist from scipy.optimize import linear_sum_assignment as lsap -class MappingMatrixObjective: +class MappingMatrixObjective(ABC): def invalid_tensor(self, shape: Union[Tuple, int]) -> np.ndarray: return np.ones(shape) * self.invalid_value @@ -51,16 +52,19 @@ def invalid_value(self) -> float: return -1e10 if self.maximize else 1e10 @property + @abstractmethod def maximize(self) -> bool: - raise NotImplementedError() + pass @property + @abstractmethod def best_possible_value(self) -> float: - raise NotImplementedError() + pass @property + @abstractmethod def best_value_fn(self) -> Callable: - raise NotImplementedError() + pass class MinimizationObjective(MappingMatrixObjective): diff --git a/src/diart/models.py b/src/diart/models.py index df66e166..42056c44 100644 --- a/src/diart/models.py +++ b/src/diart/models.py @@ -1,3 +1,4 @@ +from abc import ABC, abstractmethod from typing import Optional, Text, Union, Callable import torch @@ -20,7 +21,7 @@ def __call__(self) -> nn.Module: return pyannote_loader.get_model(self.model_info, self.hf_token) -class LazyModel(nn.Module): +class LazyModel(nn.Module, ABC): def __init__(self, loader: Callable[[], nn.Module]): super().__init__() self.get_model = loader @@ -69,13 +70,16 @@ def from_pyannote(model, use_hf_token: Union[Text, bool, None] = True) -> 'Segme return PyannoteSegmentationModel(model, use_hf_token) @property + @abstractmethod def sample_rate(self) -> int: - raise NotImplementedError + pass @property + @abstractmethod def duration(self) -> float: - raise NotImplementedError + pass + @abstractmethod def forward(self, waveform: torch.Tensor) -> torch.Tensor: """ Forward pass of the segmentation model. 
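
Where the abstract member is also a `@property` or `@staticmethod` (as in `MappingMatrixObjective.maximize` or `PipelineConfig.from_dict` above), `@abstractmethod` is stacked as the innermost decorator, which is the ordering `abc` requires. A small self-contained sketch of that stacking, with illustrative names only:

```python
from abc import ABC, abstractmethod


class ConfigLike(ABC):  # hypothetical stand-in, not a diart class
    @property
    @abstractmethod     # must be the innermost decorator
    def duration(self) -> float:
        pass

    @staticmethod
    @abstractmethod     # same rule for abstract static methods
    def from_dict(data: dict) -> "ConfigLike":
        pass


class FixedConfig(ConfigLike):
    @property
    def duration(self) -> float:
        return 5.0

    @staticmethod
    def from_dict(data: dict) -> "ConfigLike":
        return FixedConfig()


assert FixedConfig.from_dict({}).duration == 5.0
```
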
@@ -88,7 +92,7 @@ def forward(self, waveform: torch.Tensor) -> torch.Tensor: ------- speaker_segmentation: torch.Tensor, shape (batch, frames, speakers) """ - raise NotImplementedError + pass class PyannoteSegmentationModel(SegmentationModel): @@ -132,6 +136,7 @@ def from_pyannote(model, use_hf_token: Union[Text, bool, None] = True) -> 'Embed assert _has_pyannote, "No pyannote.audio installation found" return PyannoteEmbeddingModel(model, use_hf_token) + @abstractmethod def forward( self, waveform: torch.Tensor, @@ -150,7 +155,7 @@ def forward( ------- speaker_embeddings: torch.Tensor, shape (batch, embedding_dim) """ - raise NotImplementedError + pass class PyannoteEmbeddingModel(EmbeddingModel): diff --git a/src/diart/progress.py b/src/diart/progress.py index 240cc628..ca62e8bb 100644 --- a/src/diart/progress.py +++ b/src/diart/progress.py @@ -1,3 +1,4 @@ +from abc import ABC, abstractmethod from typing import Optional, Text import rich @@ -5,32 +6,40 @@ from tqdm import tqdm -class ProgressBar: +class ProgressBar(ABC): + @abstractmethod def create(self, total: int, description: Optional[Text] = None, unit: Text = "it", **kwargs): - raise NotImplementedError + pass + @abstractmethod def start(self): - raise NotImplementedError + pass + @abstractmethod def update(self, n: int = 1): - raise NotImplementedError + pass + @abstractmethod def write(self, text: Text): - raise NotImplementedError + pass + @abstractmethod def stop(self): - raise NotImplementedError + pass + @abstractmethod def close(self): - raise NotImplementedError + pass @property + @abstractmethod def default_description(self) -> Text: - raise NotImplementedError + pass @property + @abstractmethod def initial_description(self) -> Optional[Text]: - raise NotImplementedError + pass def resolve_description(self, new_description: Optional[Text] = None) -> Text: if self.initial_description is None: diff --git a/src/diart/sources.py b/src/diart/sources.py index b34d5cf3..abfa8b6f 100644 --- a/src/diart/sources.py +++ b/src/diart/sources.py @@ -1,3 +1,4 @@ +from abc import ABC, abstractmethod from pathlib import Path from queue import SimpleQueue from typing import Text, Optional, AnyStr, Dict, Any, Union, Tuple @@ -14,7 +15,7 @@ from .audio import FilePath, AudioLoader -class AudioSource: +class AudioSource(ABC): """Represents a source of audio that can start streaming via the `stream` property. Parameters @@ -34,13 +35,15 @@ def duration(self) -> Optional[float]: """The duration of the stream if known. Defaults to None (unknown duration).""" return None + @abstractmethod def read(self): """Start reading the source and yielding samples through the stream.""" - raise NotImplementedError + pass + @abstractmethod def close(self): """Stop reading the source and close all open streams.""" - raise NotImplementedError + pass class FileAudioSource(AudioSource): From 410ab89bd39c0860a227a6a74944540ae7299cac Mon Sep 17 00:00:00 2001 From: Juan Coria Date: Wed, 11 Oct 2023 16:13:00 +0200 Subject: [PATCH 08/14] Fix sample rate issues (#153) * Add automatic sample rate detection in MicrophoneAudioSource. Fix resampling crash. 
* Replace block_size by block_duration in audio source constructors --- README.md | 9 ++++---- src/diart/blocks/base.py | 3 --- src/diart/blocks/utils.py | 9 +++++--- src/diart/console/client.py | 5 ++--- src/diart/console/stream.py | 5 ++--- src/diart/inference.py | 12 ++++++---- src/diart/sources.py | 44 +++++++++++++++++++++++-------------- 7 files changed, 49 insertions(+), 38 deletions(-) diff --git a/README.md b/README.md index 0ddc935f..caef5045 100644 --- a/README.md +++ b/README.md @@ -113,7 +113,7 @@ from diart.inference import StreamingInference from diart.sinks import RTTMWriter pipeline = SpeakerDiarization() -mic = MicrophoneAudioSource(pipeline.config.sample_rate) +mic = MicrophoneAudioSource() inference = StreamingInference(pipeline, mic, do_plot=True) inference.attach_observers(RTTMWriter(mic.uri, "/output/file.rttm")) prediction = inference() @@ -167,7 +167,7 @@ config = SpeakerDiarizationConfig( embedding=MyEmbeddingModel() ) pipeline = SpeakerDiarization(config) -mic = MicrophoneAudioSource(config.sample_rate) +mic = MicrophoneAudioSource() inference = StreamingInference(pipeline, mic) prediction = inference() ``` @@ -241,12 +241,11 @@ from diart.blocks import SpeakerSegmentation, OverlapAwareSpeakerEmbedding segmentation = SpeakerSegmentation.from_pyannote("pyannote/segmentation") embedding = OverlapAwareSpeakerEmbedding.from_pyannote("pyannote/embedding") -sample_rate = segmentation.model.sample_rate -mic = MicrophoneAudioSource(sample_rate) +mic = MicrophoneAudioSource() stream = mic.stream.pipe( # Reformat stream to 5s duration and 500ms shift - dops.rearrange_audio_stream(sample_rate=sample_rate), + dops.rearrange_audio_stream(sample_rate=segmentation.model.sample_rate), ops.map(lambda wav: (wav, segmentation(wav))), ops.starmap(embedding) ).subscribe(on_next=lambda emb: print(emb.shape)) diff --git a/src/diart/blocks/base.py b/src/diart/blocks/base.py index 11d298b0..d3986e1b 100644 --- a/src/diart/blocks/base.py +++ b/src/diart/blocks/base.py @@ -64,9 +64,6 @@ def get_file_padding(self, filepath: FilePath) -> Tuple[float, float]: left = utils.get_padding_left(file_duration + right, self.duration) return left, right - def optimal_block_size(self) -> int: - return int(np.rint(self.step * self.sample_rate)) - class Pipeline(ABC): @staticmethod diff --git a/src/diart/blocks/utils.py b/src/diart/blocks/utils.py index 02594e3d..9af5cb9d 100644 --- a/src/diart/blocks/utils.py +++ b/src/diart/blocks/utils.py @@ -69,12 +69,15 @@ class Resample: resample_rate: int Sample rate of the output """ - def __init__(self, sample_rate: int, resample_rate: int): - self.resample = T.Resample(sample_rate, resample_rate) + def __init__(self, sample_rate: int, resample_rate: int, device: Optional[torch.device] = None): + self.device = device + if self.device is None: + self.device = torch.device("cpu") + self.resample = T.Resample(sample_rate, resample_rate).to(self.device) self.formatter = TemporalFeatureFormatter() def __call__(self, waveform: TemporalFeatures) -> TemporalFeatures: - wav = self.formatter.cast(waveform) # shape (batch, samples, 1) + wav = self.formatter.cast(waveform).to(self.device) # shape (batch, samples, 1) with torch.no_grad(): resampled_wav = self.resample(wav.transpose(-1, -2)).transpose(-1, -2) return self.formatter.restore_type(resampled_wav) diff --git a/src/diart/console/client.py b/src/diart/console/client.py index db4915fa..d1896ec6 100644 --- a/src/diart/console/client.py +++ b/src/diart/console/client.py @@ -13,13 +13,12 @@ def send_audio(ws: WebSocket, 
source: Text, step: float, sample_rate: int): # Create audio source - block_size = int(np.rint(step * sample_rate)) source_components = source.split(":") if source_components[0] != "microphone": - audio_source = src.FileAudioSource(source, sample_rate) + audio_source = src.FileAudioSource(source, sample_rate, block_duration=step) else: device = int(source_components[1]) if len(source_components) > 1 else None - audio_source = src.MicrophoneAudioSource(sample_rate, block_size, device) + audio_source = src.MicrophoneAudioSource(step, device) # Encode audio, then send through websocket audio_source.stream.pipe( diff --git a/src/diart/console/stream.py b/src/diart/console/stream.py index e0c670c5..fd7df5eb 100644 --- a/src/diart/console/stream.py +++ b/src/diart/console/stream.py @@ -40,18 +40,17 @@ def run(): pipeline = pipeline_class(config) # Manage audio source - block_size = config.optimal_block_size() source_components = args.source.split(":") if source_components[0] != "microphone": args.source = Path(args.source).expanduser() args.output = args.source.parent if args.output is None else Path(args.output) padding = config.get_file_padding(args.source) - audio_source = src.FileAudioSource(args.source, config.sample_rate, padding, block_size) + audio_source = src.FileAudioSource(args.source, config.sample_rate, padding, config.step) pipeline.set_timestamp_shift(-padding[0]) else: args.output = Path("~/").expanduser() if args.output is None else Path(args.output) device = int(source_components[1]) if len(source_components) > 1 else None - audio_source = src.MicrophoneAudioSource(config.sample_rate, block_size, device) + audio_source = src.MicrophoneAudioSource(config.step, device) # Run online inference inference = StreamingInference( diff --git a/src/diart/inference.py b/src/diart/inference.py index f562fdd9..99e5c757 100644 --- a/src/diart/inference.py +++ b/src/diart/inference.py @@ -97,18 +97,22 @@ def __init__( self.stream = self.source.stream + # Rearrange stream to form sliding windows + self.stream = self.stream.pipe( + dops.rearrange_audio_stream(chunk_duration, step_duration, source.sample_rate), + ) + # Dynamic resampling if the audio source isn't compatible if sample_rate != self.source.sample_rate: msg = f"Audio source has sample rate {self.source.sample_rate}, " \ f"but pipeline's is {sample_rate}. Will resample." logging.warning(msg) self.stream = self.stream.pipe( - ops.map(blocks.Resample(self.source.sample_rate, sample_rate)) + ops.map(blocks.Resample(self.source.sample_rate, sample_rate, self.pipeline.config.device)) ) - # Add rx operators to manage the inputs and outputs of the pipeline + # Form batches self.stream = self.stream.pipe( - dops.rearrange_audio_stream(chunk_duration, step_duration, sample_rate), ops.buffer_with_count(count=self.batch_size), ) @@ -316,7 +320,7 @@ def run_single( filepath, pipeline.config.sample_rate, padding, - pipeline.config.optimal_block_size(), + pipeline.config.step, ) pipeline.set_timestamp_shift(-padding[0]) inference = StreamingInference( diff --git a/src/diart/sources.py b/src/diart/sources.py index abfa8b6f..5ae6c0eb 100644 --- a/src/diart/sources.py +++ b/src/diart/sources.py @@ -58,23 +58,23 @@ class FileAudioSource(AudioSource): padding: (float, float) Left and right padding to add to the file (in seconds). Defaults to (0, 0). - block_size: int - Number of samples per chunk emitted. - Defaults to 1000. + block_duration: int + Duration of each emitted chunk in seconds. + Defaults to 0.5 seconds. 
""" def __init__( self, file: FilePath, sample_rate: int, padding: Tuple[float, float] = (0, 0), - block_size: int = 1000, + block_duration: float = 0.5, ): super().__init__(Path(file).stem, sample_rate) self.loader = AudioLoader(self.sample_rate, mono=True) self._duration = self.loader.get_duration(file) self.file = file self.resolution = 1 / self.sample_rate - self.block_size = block_size + self.block_size = int(np.rint(block_duration * self.sample_rate)) self.padding_start, self.padding_end = padding self.is_closed = False @@ -134,11 +134,9 @@ class MicrophoneAudioSource(AudioSource): Parameters ---------- - sample_rate: int - Sample rate for the emitted audio chunks. - block_size: int - Number of samples per chunk emitted. - Defaults to 1000. + block_duration: int + Duration of each emitted chunk in seconds. + Defaults to 0.5 seconds. device: int | str | (int, str) | None Device identifier compatible for the sounddevice stream. If None, use the default device. @@ -147,15 +145,27 @@ class MicrophoneAudioSource(AudioSource): def __init__( self, - sample_rate: int, - block_size: int = 1000, + block_duration: float = 0.5, device: Optional[Union[int, Text, Tuple[int, Text]]] = None, ): - super().__init__("live_recording", sample_rate) - self.block_size = block_size + # Use the lowest supported sample rate + sample_rates = [16000, 32000, 44100, 48000] + best_sample_rate = None + for sr in sample_rates: + try: + sd.check_input_settings(device=device, samplerate=sr) + except Exception: + pass + else: + best_sample_rate = sr + break + super().__init__(f"input_device:{device}", best_sample_rate) + + # Determine block size in samples and create input stream + self.block_size = int(np.rint(block_duration * self.sample_rate)) self._mic_stream = sd.InputStream( channels=1, - samplerate=sample_rate, + samplerate=self.sample_rate, latency=0, blocksize=self.block_size, callback=self._read_callback, @@ -261,10 +271,10 @@ def __init__( sample_rate: int, streamer: StreamReader, stream_index: Optional[int] = None, - block_size: int = 1000, + block_duration: float = 0.5, ): super().__init__(uri, sample_rate) - self.block_size = block_size + self.block_size = int(np.rint(block_duration * self.sample_rate)) self._streamer = streamer self._streamer.add_basic_audio_stream( frames_per_chunk=self.block_size, From 25d2196d2fd7340cbec5158bd97c77b6f73880e3 Mon Sep 17 00:00:00 2001 From: Juan Coria Date: Wed, 11 Oct 2023 16:34:48 +0200 Subject: [PATCH 09/14] Fix torchaudio version incompatibility (#181) --- requirements.txt | 2 +- setup.cfg | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 50662023..e0d93213 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,7 +8,7 @@ tqdm>=4.64.0 pandas>=1.4.2 torch>=1.12.1 torchvision>=0.14.0 -torchaudio>=0.12.1,<1.0 +torchaudio>=2.0.2 pyannote.audio>=2.1.1 pyannote.core>=4.5 pyannote.database>=4.1.1 diff --git a/setup.cfg b/setup.cfg index e67e4426..c2e536d8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -30,7 +30,7 @@ install_requires= pandas>=1.4.2 torch>=1.12.1 torchvision>=0.14.0 - torchaudio>=0.12.1,<1.0 + torchaudio>=2.0.2 pyannote.audio>=2.1.1 pyannote.core>=4.5 pyannote.database>=4.1.1 From 1d1d826c2501a94c9e6eb9268130e603d9272fdb Mon Sep 17 00:00:00 2001 From: Juan Coria Date: Wed, 11 Oct 2023 17:00:46 +0200 Subject: [PATCH 10/14] Add special audio source for Apple devices (#182) * Blacken entire code base * Add AppleDeviceAudioSource --- src/diart/argdoc.py | 1 + src/diart/blocks/aggregation.py | 56 
++++++++++----- src/diart/blocks/base.py | 7 +- src/diart/blocks/clustering.py | 34 ++++----- src/diart/blocks/diarization.py | 31 ++++++--- src/diart/blocks/embedding.py | 22 ++++-- src/diart/blocks/segmentation.py | 9 ++- src/diart/blocks/utils.py | 13 +++- src/diart/blocks/vad.py | 26 +++++-- src/diart/console/benchmark.py | 107 +++++++++++++++++++++------- src/diart/console/client.py | 33 ++++++--- src/diart/console/serve.py | 83 ++++++++++++++++------ src/diart/console/stream.py | 110 ++++++++++++++++++++++------- src/diart/console/tune.py | 115 +++++++++++++++++++++++-------- src/diart/features.py | 5 +- src/diart/inference.py | 39 ++++++++--- src/diart/models.py | 15 ++-- src/diart/operators.py | 48 ++++++++----- src/diart/optim.py | 17 +++-- src/diart/progress.py | 30 ++++++-- src/diart/sinks.py | 17 +++-- src/diart/sources.py | 26 ++++++- src/diart/utils.py | 2 + 23 files changed, 620 insertions(+), 226 deletions(-) diff --git a/src/diart/argdoc.py b/src/diart/argdoc.py index d16df41e..e89caa28 100644 --- a/src/diart/argdoc.py +++ b/src/diart/argdoc.py @@ -1,5 +1,6 @@ SEGMENTATION = "Segmentation model name from pyannote" EMBEDDING = "Embedding model name from pyannote" +DURATION = "Chunk duration (in seconds)" STEP = "Sliding window step (in seconds)" LATENCY = "System latency (in seconds). STEP <= LATENCY <= CHUNK_DURATION" TAU = "Probability threshold to consider a speaker as active. 0 <= TAU <= 1" diff --git a/src/diart/blocks/aggregation.py b/src/diart/blocks/aggregation.py index 41b836fe..aa5e6a1e 100644 --- a/src/diart/blocks/aggregation.py +++ b/src/diart/blocks/aggregation.py @@ -18,14 +18,18 @@ class AggregationStrategy(ABC): """ def __init__(self, cropping_mode: Literal["strict", "loose", "center"] = "loose"): - assert cropping_mode in ["strict", "loose", "center"], f"Invalid cropping mode `{cropping_mode}`" + assert cropping_mode in [ + "strict", + "loose", + "center", + ], f"Invalid cropping mode `{cropping_mode}`" self.cropping_mode = cropping_mode @staticmethod def build( name: Literal["mean", "hamming", "first"], - cropping_mode: Literal["strict", "loose", "center"] = "loose" - ) -> 'AggregationStrategy': + cropping_mode: Literal["strict", "loose", "center"] = "loose", + ) -> "AggregationStrategy": """Build an AggregationStrategy instance based on its name""" assert name in ("mean", "hamming", "first") if name == "mean": @@ -35,7 +39,9 @@ def build( else: return FirstOnlyStrategy(cropping_mode) - def __call__(self, buffers: List[SlidingWindowFeature], focus: Segment) -> SlidingWindowFeature: + def __call__( + self, buffers: List[SlidingWindowFeature], focus: Segment + ) -> SlidingWindowFeature: """Aggregate chunks over a specific region. 
Parameters @@ -53,21 +59,23 @@ def __call__(self, buffers: List[SlidingWindowFeature], focus: Segment) -> Slidi aggregation = self.aggregate(buffers, focus) resolution = focus.duration / aggregation.shape[0] resolution = SlidingWindow( - start=focus.start, - duration=resolution, - step=resolution + start=focus.start, duration=resolution, step=resolution ) return SlidingWindowFeature(aggregation, resolution) @abstractmethod - def aggregate(self, buffers: List[SlidingWindowFeature], focus: Segment) -> np.ndarray: + def aggregate( + self, buffers: List[SlidingWindowFeature], focus: Segment + ) -> np.ndarray: pass class HammingWeightedAverageStrategy(AggregationStrategy): """Compute the average weighted by the corresponding Hamming-window aligned to each buffer""" - def aggregate(self, buffers: List[SlidingWindowFeature], focus: Segment) -> np.ndarray: + def aggregate( + self, buffers: List[SlidingWindowFeature], focus: Segment + ) -> np.ndarray: num_frames, num_speakers = buffers[0].data.shape hamming, intersection = [], [] for buffer in buffers: @@ -87,19 +95,25 @@ def aggregate(self, buffers: List[SlidingWindowFeature], focus: Segment) -> np.n class AverageStrategy(AggregationStrategy): """Compute a simple average over the focus region""" - def aggregate(self, buffers: List[SlidingWindowFeature], focus: Segment) -> np.ndarray: + def aggregate( + self, buffers: List[SlidingWindowFeature], focus: Segment + ) -> np.ndarray: # Stack all overlapping regions - intersection = np.stack([ - buffer.crop(focus, mode=self.cropping_mode, fixed=focus.duration) - for buffer in buffers - ]) + intersection = np.stack( + [ + buffer.crop(focus, mode=self.cropping_mode, fixed=focus.duration) + for buffer in buffers + ] + ) return np.mean(intersection, axis=0) class FirstOnlyStrategy(AggregationStrategy): """Instead of aggregating, keep the first focus region in the buffer list""" - def aggregate(self, buffers: List[SlidingWindowFeature], focus: Segment) -> np.ndarray: + def aggregate( + self, buffers: List[SlidingWindowFeature], focus: Segment + ) -> np.ndarray: return buffers[0].crop(focus, mode=self.cropping_mode, fixed=focus.duration) @@ -151,12 +165,16 @@ def __init__( step: float, latency: Optional[float] = None, strategy: Literal["mean", "hamming", "first"] = "hamming", - cropping_mode: Literal["strict", "loose", "center"] = "loose" + cropping_mode: Literal["strict", "loose", "center"] = "loose", ): self.step = step self.latency = latency self.strategy = strategy - assert cropping_mode in ["strict", "loose", "center"], f"Invalid cropping mode `{cropping_mode}`" + assert cropping_mode in [ + "strict", + "loose", + "center", + ], f"Invalid cropping mode `{cropping_mode}`" self.cropping_mode = cropping_mode if self.latency is None: @@ -171,7 +189,7 @@ def _prepend( self, output_window: SlidingWindowFeature, output_region: Segment, - buffers: List[SlidingWindowFeature] + buffers: List[SlidingWindowFeature], ): # FIXME instead of prepending the output of the first chunk, # add padding of `chunk_duration - latency` seconds at the @@ -189,7 +207,7 @@ def _prepend( resolution = output_region.end / first_output.shape[0] output_window = SlidingWindowFeature( first_output, - SlidingWindow(start=0, duration=resolution, step=resolution) + SlidingWindow(start=0, duration=resolution, step=resolution), ) return output_window diff --git a/src/diart/blocks/base.py b/src/diart/blocks/base.py index d3986e1b..6536a3f7 100644 --- a/src/diart/blocks/base.py +++ b/src/diart/blocks/base.py @@ -17,7 +17,7 @@ class HyperParameter: 
high: float @staticmethod - def from_name(name: Text) -> 'HyperParameter': + def from_name(name: Text) -> "HyperParameter": if name == "tau_active": return TauActive if name == "rho_update": @@ -55,7 +55,7 @@ def sample_rate(self) -> int: @staticmethod @abstractmethod - def from_dict(data: Any) -> 'PipelineConfig': + def from_dict(data: Any) -> "PipelineConfig": pass def get_file_padding(self, filepath: FilePath) -> Tuple[float, float]: @@ -96,7 +96,6 @@ def set_timestamp_shift(self, shift: float): @abstractmethod def __call__( - self, - waveforms: Sequence[SlidingWindowFeature] + self, waveforms: Sequence[SlidingWindowFeature] ) -> Sequence[Tuple[Any, SlidingWindowFeature]]: pass diff --git a/src/diart/blocks/clustering.py b/src/diart/blocks/clustering.py index 882001b9..b7217c0a 100644 --- a/src/diart/blocks/clustering.py +++ b/src/diart/blocks/clustering.py @@ -27,13 +27,14 @@ class OnlineSpeakerClustering: max_speakers: int Maximum number of global speakers to track through a conversation. Defaults to 20. """ + def __init__( self, tau_active: float, rho_update: float, delta_new: float, metric: Optional[str] = "cosine", - max_speakers: int = 20 + max_speakers: int = 20, ): self.tau_active = tau_active self.rho_update = rho_update @@ -116,9 +117,7 @@ def add_center(self, embedding: np.ndarray) -> int: return center def identify( - self, - segmentation: SlidingWindowFeature, - embeddings: torch.Tensor + self, segmentation: SlidingWindowFeature, embeddings: torch.Tensor ) -> SpeakerMap: """Identify the centroids to which the input speaker embeddings belong. @@ -135,15 +134,18 @@ def identify( A mapping from local speakers to global speakers. """ embeddings = embeddings.detach().cpu().numpy() - active_speakers = np.where(np.max(segmentation.data, axis=0) >= self.tau_active)[0] - long_speakers = np.where(np.mean(segmentation.data, axis=0) >= self.rho_update)[0] + active_speakers = np.where( + np.max(segmentation.data, axis=0) >= self.tau_active + )[0] + long_speakers = np.where(np.mean(segmentation.data, axis=0) >= self.rho_update)[ + 0 + ] num_local_speakers = segmentation.data.shape[1] if self.centers is None: self.init_centers(embeddings.shape[1]) assignments = [ - (spk, self.add_center(embeddings[spk])) - for spk in active_speakers + (spk, self.add_center(embeddings[spk])) for spk in active_speakers ] return SpeakerMapBuilder.hard_map( shape=(num_local_speakers, self.max_speakers), @@ -154,18 +156,16 @@ def identify( # Obtain a mapping based on distances between embeddings and centers dist_map = SpeakerMapBuilder.dist(embeddings, self.centers, self.metric) # Remove any assignments containing invalid speakers - inactive_speakers = np.array([ - spk for spk in range(num_local_speakers) - if spk not in active_speakers - ]) + inactive_speakers = np.array( + [spk for spk in range(num_local_speakers) if spk not in active_speakers] + ) dist_map = dist_map.unmap_speakers(inactive_speakers, self.inactive_centers) # Keep assignments under the distance threshold valid_map = dist_map.unmap_threshold(self.delta_new) # Some speakers might be unidentified missed_speakers = [ - s for s in active_speakers - if not valid_map.is_source_speaker_mapped(s) + s for s in active_speakers if not valid_map.is_source_speaker_mapped(s) ] # Add assignments to new centers if possible @@ -205,8 +205,10 @@ def identify( return valid_map - def __call__(self, segmentation: SlidingWindowFeature, embeddings: torch.Tensor) -> SlidingWindowFeature: + def __call__( + self, segmentation: SlidingWindowFeature, embeddings: 
torch.Tensor + ) -> SlidingWindowFeature: return SlidingWindowFeature( self.identify(segmentation, embeddings).apply(segmentation.data), - segmentation.sliding_window + segmentation.sliding_window, ) diff --git a/src/diart/blocks/diarization.py b/src/diart/blocks/diarization.py index ec9a73f3..3cf4e333 100644 --- a/src/diart/blocks/diarization.py +++ b/src/diart/blocks/diarization.py @@ -37,7 +37,9 @@ def __init__( # Default segmentation model is pyannote/segmentation self.segmentation = segmentation if self.segmentation is None: - self.segmentation = m.SegmentationModel.from_pyannote("pyannote/segmentation") + self.segmentation = m.SegmentationModel.from_pyannote( + "pyannote/segmentation" + ) self._duration = duration self._sample_rate: Optional[int] = None @@ -67,7 +69,7 @@ def __init__( self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @staticmethod - def from_dict(data: Any) -> 'SpeakerDiarizationConfig': + def from_dict(data: Any) -> "SpeakerDiarizationConfig": # Check for explicit device, otherwise check for 'cpu' bool, otherwise pass None device = utils.get(data, "device", None) if device is None: @@ -136,9 +138,15 @@ def __init__(self, config: Optional[SpeakerDiarizationConfig] = None): msg = f"Latency should be in the range [{self._config.step}, {self._config.duration}]" assert self._config.step <= self._config.latency <= self._config.duration, msg - self.segmentation = SpeakerSegmentation(self._config.segmentation, self._config.device) + self.segmentation = SpeakerSegmentation( + self._config.segmentation, self._config.device + ) self.embedding = OverlapAwareSpeakerEmbedding( - self._config.embedding, self._config.gamma, self._config.beta, norm=1, device=self._config.device + self._config.embedding, + self._config.gamma, + self._config.beta, + norm=1, + device=self._config.device, ) self.pred_aggregation = DelayedAggregation( self._config.step, @@ -191,8 +199,7 @@ def reset(self): self.chunk_buffer, self.pred_buffer = [], [] def __call__( - self, - waveforms: Sequence[SlidingWindowFeature] + self, waveforms: Sequence[SlidingWindowFeature] ) -> Sequence[Tuple[Annotation, SlidingWindowFeature]]: batch_size = len(waveforms) msg = "Pipeline expected at least 1 input" @@ -201,13 +208,17 @@ def __call__( # Create batch from chunk sequence, shape (batch, samples, channels) batch = torch.stack([torch.from_numpy(w.data) for w in waveforms]) - expected_num_samples = int(np.rint(self.config.duration * self.config.sample_rate)) + expected_num_samples = int( + np.rint(self.config.duration * self.config.sample_rate) + ) msg = f"Expected {expected_num_samples} samples per chunk, but got {batch.shape[1]}" assert batch.shape[1] == expected_num_samples, msg # Extract segmentation and embeddings segmentations = self.segmentation(batch) # shape (batch, frames, speakers) - embeddings = self.embedding(batch, segmentations) # shape (batch, speakers, emb_dim) + embeddings = self.embedding( + batch, segmentations + ) # shape (batch, speakers, emb_dim) seg_resolution = waveforms[0].extent.duration / segmentations.shape[1] @@ -236,7 +247,9 @@ def __call__( # Shift prediction timestamps if required if self.timestamp_shift != 0: shifted_agg_prediction = Annotation(agg_prediction.uri) - for segment, track, speaker in agg_prediction.itertracks(yield_label=True): + for segment, track, speaker in agg_prediction.itertracks( + yield_label=True + ): new_segment = Segment( segment.start + self.timestamp_shift, segment.end + self.timestamp_shift, diff --git a/src/diart/blocks/embedding.py 
b/src/diart/blocks/embedding.py index 7aa31c05..5cd7c39e 100644 --- a/src/diart/blocks/embedding.py +++ b/src/diart/blocks/embedding.py @@ -22,12 +22,14 @@ def __init__(self, model: EmbeddingModel, device: Optional[torch.device] = None) def from_pyannote( model, use_hf_token: Union[Text, bool, None] = True, - device: Optional[torch.device] = None - ) -> 'SpeakerEmbedding': + device: Optional[torch.device] = None, + ) -> "SpeakerEmbedding": emb_model = EmbeddingModel.from_pyannote(model, use_hf_token) return SpeakerEmbedding(emb_model, device) - def __call__(self, waveform: TemporalFeatures, weights: Optional[TemporalFeatures] = None) -> torch.Tensor: + def __call__( + self, waveform: TemporalFeatures, weights: Optional[TemporalFeatures] = None + ) -> torch.Tensor: """ Calculate speaker embeddings of input audio. If weights are given, calculate many speaker embeddings from the same waveform. @@ -58,7 +60,7 @@ def __call__(self, waveform: TemporalFeatures, weights: Optional[TemporalFeature self.model(inputs, weights), "(batch spk) feat -> batch spk feat", batch=batch_size, - spk=num_speakers + spk=num_speakers, ) else: output = self.model(inputs) @@ -76,6 +78,7 @@ class OverlappedSpeechPenalty: Temperature parameter (actually 1/beta) to lower joint speaker activations. Defaults to 10. """ + def __init__(self, gamma: float = 3, beta: float = 10): self.gamma = gamma self.beta = beta @@ -106,7 +109,11 @@ def __call__(self, embeddings: torch.Tensor) -> torch.Tensor: batch_size2, num_speakers2, _ = embeddings.shape assert batch_size1 == batch_size2 and num_speakers1 == num_speakers2 with torch.no_grad(): - norm_embs = self.norm * embeddings / torch.norm(embeddings, p=2, dim=-1, keepdim=True) + norm_embs = ( + self.norm + * embeddings + / torch.norm(embeddings, p=2, dim=-1, keepdim=True) + ) return norm_embs @@ -131,6 +138,7 @@ class OverlapAwareSpeakerEmbedding: The device on which to run the embedding model. Defaults to GPU if available or CPU if not. """ + def __init__( self, model: EmbeddingModel, @@ -155,5 +163,7 @@ def from_pyannote( model = EmbeddingModel.from_pyannote(model, use_hf_token) return OverlapAwareSpeakerEmbedding(model, gamma, beta, norm, device) - def __call__(self, waveform: TemporalFeatures, segmentation: TemporalFeatures) -> torch.Tensor: + def __call__( + self, waveform: TemporalFeatures, segmentation: TemporalFeatures + ) -> torch.Tensor: return self.normalize(self.embedding(waveform, self.osp(segmentation))) diff --git a/src/diart/blocks/segmentation.py b/src/diart/blocks/segmentation.py index 8fda3ffc..e946c748 100644 --- a/src/diart/blocks/segmentation.py +++ b/src/diart/blocks/segmentation.py @@ -21,8 +21,8 @@ def __init__(self, model: SegmentationModel, device: Optional[torch.device] = No def from_pyannote( model, use_hf_token: Union[Text, bool, None] = True, - device: Optional[torch.device] = None - ) -> 'SpeakerSegmentation': + device: Optional[torch.device] = None, + ) -> "SpeakerSegmentation": seg_model = SegmentationModel.from_pyannote(model, use_hf_token) return SpeakerSegmentation(seg_model, device) @@ -40,6 +40,9 @@ def __call__(self, waveform: TemporalFeatures) -> TemporalFeatures: The batch dimension is omitted if waveform is a `SlidingWindowFeature`. 
""" with torch.no_grad(): - wave = rearrange(self.formatter.cast(waveform), "batch sample channel -> batch channel sample") + wave = rearrange( + self.formatter.cast(waveform), + "batch sample channel -> batch channel sample", + ) output = self.model(wave.to(self.device)).cpu() return self.formatter.restore_type(output) diff --git a/src/diart/blocks/utils.py b/src/diart/blocks/utils.py index 9af5cb9d..9c0afc3e 100644 --- a/src/diart/blocks/utils.py +++ b/src/diart/blocks/utils.py @@ -69,7 +69,13 @@ class Resample: resample_rate: int Sample rate of the output """ - def __init__(self, sample_rate: int, resample_rate: int, device: Optional[torch.device] = None): + + def __init__( + self, + sample_rate: int, + resample_rate: int, + device: Optional[torch.device] = None, + ): self.device = device if self.device is None: self.device = torch.device("cpu") @@ -93,6 +99,7 @@ class AdjustVolume: volume_in_db: float Target volume in dB. """ + def __init__(self, volume_in_db: float): self.target_db = volume_in_db self.formatter = TemporalFeatureFormatter() @@ -111,7 +118,9 @@ def get_volumes(waveforms: torch.Tensor) -> torch.Tensor: volumes: torch.Tensor Audio chunk volumes per channel. Shape (batch, 1, channels) """ - return 10 * torch.log10(torch.mean(torch.abs(waveforms) ** 2, dim=1, keepdim=True)) + return 10 * torch.log10( + torch.mean(torch.abs(waveforms) ** 2, dim=1, keepdim=True) + ) def __call__(self, waveform: TemporalFeatures) -> TemporalFeatures: wav = self.formatter.cast(waveform) # shape (batch, samples, channels) diff --git a/src/diart/blocks/vad.py b/src/diart/blocks/vad.py index e519a9cf..04fe5608 100644 --- a/src/diart/blocks/vad.py +++ b/src/diart/blocks/vad.py @@ -2,7 +2,13 @@ import numpy as np import torch -from pyannote.core import Annotation, Timeline, SlidingWindowFeature, SlidingWindow, Segment +from pyannote.core import ( + Annotation, + Timeline, + SlidingWindowFeature, + SlidingWindow, + Segment, +) from pyannote.metrics.base import BaseMetric from pyannote.metrics.detection import DetectionErrorRate from typing_extensions import Literal @@ -29,7 +35,9 @@ def __init__( # Default segmentation model is pyannote/segmentation self.segmentation = segmentation if self.segmentation is None: - self.segmentation = m.SegmentationModel.from_pyannote("pyannote/segmentation") + self.segmentation = m.SegmentationModel.from_pyannote( + "pyannote/segmentation" + ) self._duration = duration self._step = step @@ -70,7 +78,7 @@ def sample_rate(self) -> int: return self._sample_rate @staticmethod - def from_dict(data: Any) -> 'VoiceActivityDetectionConfig': + def from_dict(data: Any) -> "VoiceActivityDetectionConfig": # Check for explicit device, otherwise check for 'cpu' bool, otherwise pass None device = utils.get(data, "device", None) if device is None: @@ -103,7 +111,9 @@ def __init__(self, config: Optional[VoiceActivityDetectionConfig] = None): msg = f"Latency should be in the range [{self._config.step}, {self._config.duration}]" assert self._config.step <= self._config.latency <= self._config.duration, msg - self.segmentation = SpeakerSegmentation(self._config.segmentation, self._config.device) + self.segmentation = SpeakerSegmentation( + self._config.segmentation, self._config.device + ) self.pred_aggregation = DelayedAggregation( self._config.step, self._config.latency, @@ -156,13 +166,17 @@ def __call__( # Create batch from chunk sequence, shape (batch, samples, channels) batch = torch.stack([torch.from_numpy(w.data) for w in waveforms]) - expected_num_samples = 
int(np.rint(self.config.duration * self.config.sample_rate)) + expected_num_samples = int( + np.rint(self.config.duration * self.config.sample_rate) + ) msg = f"Expected {expected_num_samples} samples per chunk, but got {batch.shape[1]}" assert batch.shape[1] == expected_num_samples, msg # Extract segmentation segmentations = self.segmentation(batch) # shape (batch, frames, speakers) - voice_detection = torch.max(segmentations, dim=-1, keepdim=True)[0] # shape (batch, frames, 1) + voice_detection = torch.max(segmentations, dim=-1, keepdim=True)[ + 0 + ] # shape (batch, frames, 1) seg_resolution = waveforms[0].extent.duration / segmentations.shape[1] diff --git a/src/diart/console/benchmark.py b/src/diart/console/benchmark.py index 27d524c5..70b4c3d9 100644 --- a/src/diart/console/benchmark.py +++ b/src/diart/console/benchmark.py @@ -9,31 +9,88 @@ def run(): parser = argparse.ArgumentParser() - parser.add_argument("root", type=Path, help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)") - parser.add_argument("--pipeline", default="SpeakerDiarization", type=str, - help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'") - parser.add_argument("--segmentation", default="pyannote/segmentation", type=str, - help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation") - parser.add_argument("--embedding", default="pyannote/embedding", type=str, - help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding") - parser.add_argument("--reference", type=Path, - help="Optional. Directory with RTTM files CONVERSATION.rttm. Names must match audio files") - parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5") - parser.add_argument("--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5") - parser.add_argument("--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5") - parser.add_argument("--rho", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3") - parser.add_argument("--delta", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1") - parser.add_argument("--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. Defaults to 3") - parser.add_argument("--beta", default=10, type=float, help=f"{argdoc.BETA}. Defaults to 10") - parser.add_argument("--max-speakers", default=20, type=int, help=f"{argdoc.MAX_SPEAKERS}. Defaults to 20") - parser.add_argument("--batch-size", default=32, type=int, help=f"{argdoc.BATCH_SIZE}. Defaults to 32") - parser.add_argument("--num-workers", default=0, type=int, - help=f"{argdoc.NUM_WORKERS}. Defaults to 0 (no parallelism)") - parser.add_argument("--cpu", dest="cpu", action="store_true", - help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise") - parser.add_argument("--output", type=Path, help=f"{argdoc.OUTPUT}. Defaults to no writing") - parser.add_argument("--hf-token", default="true", type=str, - help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)") + parser.add_argument( + "root", + type=Path, + help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)", + ) + parser.add_argument( + "--pipeline", + default="SpeakerDiarization", + type=str, + help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'", + ) + parser.add_argument( + "--segmentation", + default="pyannote/segmentation", + type=str, + help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation", + ) + parser.add_argument( + "--embedding", + default="pyannote/embedding", + type=str, + help=f"{argdoc.EMBEDDING}. 
Defaults to pyannote/embedding", + ) + parser.add_argument( + "--reference", + type=Path, + help="Optional. Directory with RTTM files CONVERSATION.rttm. Names must match audio files", + ) + parser.add_argument( + "--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5" + ) + parser.add_argument( + "--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5" + ) + parser.add_argument( + "--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5" + ) + parser.add_argument( + "--rho", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3" + ) + parser.add_argument( + "--delta", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1" + ) + parser.add_argument( + "--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. Defaults to 3" + ) + parser.add_argument( + "--beta", default=10, type=float, help=f"{argdoc.BETA}. Defaults to 10" + ) + parser.add_argument( + "--max-speakers", + default=20, + type=int, + help=f"{argdoc.MAX_SPEAKERS}. Defaults to 20", + ) + parser.add_argument( + "--batch-size", + default=32, + type=int, + help=f"{argdoc.BATCH_SIZE}. Defaults to 32", + ) + parser.add_argument( + "--num-workers", + default=0, + type=int, + help=f"{argdoc.NUM_WORKERS}. Defaults to 0 (no parallelism)", + ) + parser.add_argument( + "--cpu", + dest="cpu", + action="store_true", + help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise", + ) + parser.add_argument( + "--output", type=Path, help=f"{argdoc.OUTPUT}. Defaults to no writing" + ) + parser.add_argument( + "--hf-token", + default="true", + type=str, + help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)", + ) args = parser.parse_args() pipeline_class = utils.get_pipeline_class(args.pipeline) diff --git a/src/diart/console/client.py b/src/diart/console/client.py index d1896ec6..b656298a 100644 --- a/src/diart/console/client.py +++ b/src/diart/console/client.py @@ -21,9 +21,7 @@ def send_audio(ws: WebSocket, source: Text, step: float, sample_rate: int): audio_source = src.MicrophoneAudioSource(step, device) # Encode audio, then send through websocket - audio_source.stream.pipe( - ops.map(utils.encode_audio) - ).subscribe_(ws.send) + audio_source.stream.pipe(ops.map(utils.encode_audio)).subscribe_(ws.send) # Start reading audio audio_source.read() @@ -40,18 +38,37 @@ def receive_audio(ws: WebSocket, output: Optional[Path]): def run(): parser = argparse.ArgumentParser() - parser.add_argument("source", type=str, help="Path to an audio file | 'microphone' | 'microphone:'") + parser.add_argument( + "source", + type=str, + help="Path to an audio file | 'microphone' | 'microphone:'", + ) parser.add_argument("--host", required=True, type=str, help="Server host") parser.add_argument("--port", required=True, type=int, help="Server port") - parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5") - parser.add_argument("-sr", "--sample-rate", default=16000, type=int, help=f"{argdoc.SAMPLE_RATE}. Defaults to 16000") - parser.add_argument("-o", "--output-file", type=Path, help="Output RTTM file. Defaults to no writing") + parser.add_argument( + "--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5" + ) + parser.add_argument( + "-sr", + "--sample-rate", + default=16000, + type=int, + help=f"{argdoc.SAMPLE_RATE}. Defaults to 16000", + ) + parser.add_argument( + "-o", + "--output-file", + type=Path, + help="Output RTTM file. 
Defaults to no writing", + ) args = parser.parse_args() # Run websocket client ws = WebSocket() ws.connect(f"ws://{args.host}:{args.port}") - sender = Thread(target=send_audio, args=[ws, args.source, args.step, args.sample_rate]) + sender = Thread( + target=send_audio, args=[ws, args.source, args.step, args.sample_rate] + ) receiver = Thread(target=receive_audio, args=[ws, args.output_file]) sender.start() receiver.start() diff --git a/src/diart/console/serve.py b/src/diart/console/serve.py index 46bb9328..d8c059c3 100644 --- a/src/diart/console/serve.py +++ b/src/diart/console/serve.py @@ -12,25 +12,66 @@ def run(): parser = argparse.ArgumentParser() parser.add_argument("--host", default="0.0.0.0", type=str, help="Server host") parser.add_argument("--port", default=7007, type=int, help="Server port") - parser.add_argument("--pipeline", default="SpeakerDiarization", type=str, - help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'") - parser.add_argument("--segmentation", default="pyannote/segmentation", type=str, - help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation") - parser.add_argument("--embedding", default="pyannote/embedding", type=str, - help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding") - parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5") - parser.add_argument("--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5") - parser.add_argument("--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5") - parser.add_argument("--rho", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3") - parser.add_argument("--delta", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1") - parser.add_argument("--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. Defaults to 3") - parser.add_argument("--beta", default=10, type=float, help=f"{argdoc.BETA}. Defaults to 10") - parser.add_argument("--max-speakers", default=20, type=int, help=f"{argdoc.MAX_SPEAKERS}. Defaults to 20") - parser.add_argument("--cpu", dest="cpu", action="store_true", - help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise") - parser.add_argument("--output", type=Path, help=f"{argdoc.OUTPUT}. Defaults to no writing") - parser.add_argument("--hf-token", default="true", type=str, - help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)") + parser.add_argument( + "--pipeline", + default="SpeakerDiarization", + type=str, + help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'", + ) + parser.add_argument( + "--segmentation", + default="pyannote/segmentation", + type=str, + help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation", + ) + parser.add_argument( + "--embedding", + default="pyannote/embedding", + type=str, + help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding", + ) + parser.add_argument( + "--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5" + ) + parser.add_argument( + "--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5" + ) + parser.add_argument( + "--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5" + ) + parser.add_argument( + "--rho", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3" + ) + parser.add_argument( + "--delta", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1" + ) + parser.add_argument( + "--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. 
Defaults to 3" + ) + parser.add_argument( + "--beta", default=10, type=float, help=f"{argdoc.BETA}. Defaults to 10" + ) + parser.add_argument( + "--max-speakers", + default=20, + type=int, + help=f"{argdoc.MAX_SPEAKERS}. Defaults to 20", + ) + parser.add_argument( + "--cpu", + dest="cpu", + action="store_true", + help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise", + ) + parser.add_argument( + "--output", type=Path, help=f"{argdoc.OUTPUT}. Defaults to no writing" + ) + parser.add_argument( + "--hf-token", + default="true", + type=str, + help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)", + ) args = parser.parse_args() # Resolve pipeline @@ -53,7 +94,9 @@ def run(): # Write to disk if required if args.output is not None: - inference.attach_observers(RTTMWriter(audio_source.uri, args.output / f"{audio_source.uri}.rttm")) + inference.attach_observers( + RTTMWriter(audio_source.uri, args.output / f"{audio_source.uri}.rttm") + ) # Send back responses as RTTM text lines inference.attach_hooks(lambda ann_wav: audio_source.send(ann_wav[0].to_rttm())) diff --git a/src/diart/console/stream.py b/src/diart/console/stream.py index fd7df5eb..f7d96360 100644 --- a/src/diart/console/stream.py +++ b/src/diart/console/stream.py @@ -10,28 +10,82 @@ def run(): parser = argparse.ArgumentParser() - parser.add_argument("source", type=str, help="Path to an audio file | 'microphone' | 'microphone:'") - parser.add_argument("--pipeline", default="SpeakerDiarization", type=str, - help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'") - parser.add_argument("--segmentation", default="pyannote/segmentation", type=str, - help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation") - parser.add_argument("--embedding", default="pyannote/embedding", type=str, - help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding") - parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5") - parser.add_argument("--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5") - parser.add_argument("--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5") - parser.add_argument("--rho", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3") - parser.add_argument("--delta", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1") - parser.add_argument("--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. Defaults to 3") - parser.add_argument("--beta", default=10, type=float, help=f"{argdoc.BETA}. Defaults to 10") - parser.add_argument("--max-speakers", default=20, type=int, help=f"{argdoc.MAX_SPEAKERS}. Defaults to 20") - parser.add_argument("--no-plot", dest="no_plot", action="store_true", help="Skip plotting for faster inference") - parser.add_argument("--cpu", dest="cpu", action="store_true", - help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise") - parser.add_argument("--output", type=str, - help=f"{argdoc.OUTPUT}. Defaults to home directory if SOURCE == 'microphone' or parent directory if SOURCE is a file") - parser.add_argument("--hf-token", default="true", type=str, - help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)") + parser.add_argument( + "source", + type=str, + help="Path to an audio file | 'microphone' | 'microphone:'", + ) + parser.add_argument( + "--pipeline", + default="SpeakerDiarization", + type=str, + help="Class of the pipeline to optimize. 
Defaults to 'SpeakerDiarization'", + ) + parser.add_argument( + "--segmentation", + default="pyannote/segmentation", + type=str, + help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation", + ) + parser.add_argument( + "--embedding", + default="pyannote/embedding", + type=str, + help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding", + ) + parser.add_argument( + "--duration", type=float, help=f"{argdoc.DURATION}. Defaults to training segmentation duration" + ) + parser.add_argument( + "--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5" + ) + parser.add_argument( + "--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5" + ) + parser.add_argument( + "--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5" + ) + parser.add_argument( + "--rho", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3" + ) + parser.add_argument( + "--delta", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1" + ) + parser.add_argument( + "--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. Defaults to 3" + ) + parser.add_argument( + "--beta", default=10, type=float, help=f"{argdoc.BETA}. Defaults to 10" + ) + parser.add_argument( + "--max-speakers", + default=20, + type=int, + help=f"{argdoc.MAX_SPEAKERS}. Defaults to 20", + ) + parser.add_argument( + "--no-plot", + dest="no_plot", + action="store_true", + help="Skip plotting for faster inference", + ) + parser.add_argument( + "--cpu", + dest="cpu", + action="store_true", + help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise", + ) + parser.add_argument( + "--output", + type=str, + help=f"{argdoc.OUTPUT}. Defaults to home directory if SOURCE == 'microphone' or parent directory if SOURCE is a file", + ) + parser.add_argument( + "--hf-token", + default="true", + type=str, + help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)", + ) args = parser.parse_args() # Resolve pipeline @@ -45,10 +99,14 @@ def run(): args.source = Path(args.source).expanduser() args.output = args.source.parent if args.output is None else Path(args.output) padding = config.get_file_padding(args.source) - audio_source = src.FileAudioSource(args.source, config.sample_rate, padding, config.step) + audio_source = src.FileAudioSource( + args.source, config.sample_rate, padding, config.step + ) pipeline.set_timestamp_shift(-padding[0]) else: - args.output = Path("~/").expanduser() if args.output is None else Path(args.output) + args.output = ( + Path("~/").expanduser() if args.output is None else Path(args.output) + ) device = int(source_components[1]) if len(source_components) > 1 else None audio_source = src.MicrophoneAudioSource(config.step, device) @@ -61,7 +119,9 @@ def run(): do_plot=not args.no_plot, show_progress=True, ) - inference.attach_observers(RTTMWriter(audio_source.uri, args.output / f"{audio_source.uri}.rttm")) + inference.attach_observers( + RTTMWriter(audio_source.uri, args.output / f"{audio_source.uri}.rttm") + ) inference() diff --git a/src/diart/console/tune.py b/src/diart/console/tune.py index a1f1b63a..4c969efa 100644 --- a/src/diart/console/tune.py +++ b/src/diart/console/tune.py @@ -11,34 +11,95 @@ def run(): parser = argparse.ArgumentParser() - parser.add_argument("root", type=str, help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)") - parser.add_argument("--reference", required=True, type=str, - help="Directory with RTTM files CONVERSATION.rttm. 
Names must match audio files") - parser.add_argument("--pipeline", default="SpeakerDiarization", type=str, - help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'") - parser.add_argument("--segmentation", default="pyannote/segmentation", type=str, - help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation") - parser.add_argument("--embedding", default="pyannote/embedding", type=str, - help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding") - parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5") - parser.add_argument("--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5") - parser.add_argument("--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5") - parser.add_argument("--rho", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3") - parser.add_argument("--delta", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1") - parser.add_argument("--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. Defaults to 3") - parser.add_argument("--beta", default=10, type=float, help=f"{argdoc.BETA}. Defaults to 10") - parser.add_argument("--max-speakers", default=20, type=int, help=f"{argdoc.MAX_SPEAKERS}. Defaults to 20") - parser.add_argument("--batch-size", default=32, type=int, help=f"{argdoc.BATCH_SIZE}. Defaults to 32") - parser.add_argument("--cpu", dest="cpu", action="store_true", - help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise") - parser.add_argument("--hparams", nargs="+", default=("tau_active", "rho_update", "delta_new"), - help="Hyper-parameters to optimize. Must match names in `PipelineConfig`. Defaults to tau_active, rho_update and delta_new") - parser.add_argument("--num-iter", default=100, type=int, help="Number of optimization trials") - parser.add_argument("--storage", type=str, - help="Optuna storage string. If provided, continue a previous study instead of creating one. The database name must match the study name") + parser.add_argument( + "root", + type=str, + help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)", + ) + parser.add_argument( + "--reference", + required=True, + type=str, + help="Directory with RTTM files CONVERSATION.rttm. Names must match audio files", + ) + parser.add_argument( + "--pipeline", + default="SpeakerDiarization", + type=str, + help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'", + ) + parser.add_argument( + "--segmentation", + default="pyannote/segmentation", + type=str, + help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation", + ) + parser.add_argument( + "--embedding", + default="pyannote/embedding", + type=str, + help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding", + ) + parser.add_argument( + "--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5" + ) + parser.add_argument( + "--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5" + ) + parser.add_argument( + "--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5" + ) + parser.add_argument( + "--rho", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3" + ) + parser.add_argument( + "--delta", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1" + ) + parser.add_argument( + "--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. Defaults to 3" + ) + parser.add_argument( + "--beta", default=10, type=float, help=f"{argdoc.BETA}. 
Defaults to 10" + ) + parser.add_argument( + "--max-speakers", + default=20, + type=int, + help=f"{argdoc.MAX_SPEAKERS}. Defaults to 20", + ) + parser.add_argument( + "--batch-size", + default=32, + type=int, + help=f"{argdoc.BATCH_SIZE}. Defaults to 32", + ) + parser.add_argument( + "--cpu", + dest="cpu", + action="store_true", + help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise", + ) + parser.add_argument( + "--hparams", + nargs="+", + default=("tau_active", "rho_update", "delta_new"), + help="Hyper-parameters to optimize. Must match names in `PipelineConfig`. Defaults to tau_active, rho_update and delta_new", + ) + parser.add_argument( + "--num-iter", default=100, type=int, help="Number of optimization trials" + ) + parser.add_argument( + "--storage", + type=str, + help="Optuna storage string. If provided, continue a previous study instead of creating one. The database name must match the study name", + ) parser.add_argument("--output", type=str, help="Working directory") - parser.add_argument("--hf-token", default="true", type=str, - help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)") + parser.add_argument( + "--hf-token", + default="true", + type=str, + help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)", + ) args = parser.parse_args() # Retrieve pipeline class diff --git a/src/diart/features.py b/src/diart/features.py index ffa83c8f..2d5df672 100644 --- a/src/diart/features.py +++ b/src/diart/features.py @@ -52,7 +52,9 @@ def to_internal_type(self, features: torch.Tensor) -> TemporalFeatures: # Calculate resolution resolution = self.duration / num_frames # Temporal shift to keep track of current start time - resolution = SlidingWindow(start=self._cur_start_time, duration=resolution, step=resolution) + resolution = SlidingWindow( + start=self._cur_start_time, duration=resolution, step=resolution + ) return SlidingWindowFeature(features.squeeze(dim=0).cpu().numpy(), resolution) @@ -78,6 +80,7 @@ class TemporalFeatureFormatter: When casting temporal features as torch.Tensor, it remembers its type and format so it can lately restore it on other temporal features. """ + def __init__(self): self.state: Optional[TemporalFeatureFormatterState] = None diff --git a/src/diart/inference.py b/src/diart/inference.py index 99e5c757..3eb72930 100644 --- a/src/diart/inference.py +++ b/src/diart/inference.py @@ -51,6 +51,7 @@ class StreamingInference: If description is not provided, set to 'Streaming '. Defaults to RichProgressBar(). """ + def __init__( self, pipeline: blocks.Pipeline, @@ -89,7 +90,7 @@ def __init__( self._pbar.create( total=self.num_chunks, description=f"Streaming {self.source.uri}", - unit=self.unit + unit=self.unit, ) # Initialize chronometer for profiling @@ -99,16 +100,26 @@ def __init__( # Rearrange stream to form sliding windows self.stream = self.stream.pipe( - dops.rearrange_audio_stream(chunk_duration, step_duration, source.sample_rate), + dops.rearrange_audio_stream( + chunk_duration, step_duration, source.sample_rate + ), ) # Dynamic resampling if the audio source isn't compatible if sample_rate != self.source.sample_rate: - msg = f"Audio source has sample rate {self.source.sample_rate}, " \ - f"but pipeline's is {sample_rate}. Will resample." + msg = ( + f"Audio source has sample rate {self.source.sample_rate}, " + f"but pipeline's is {sample_rate}. Will resample." 
+ ) logging.warning(msg) self.stream = self.stream.pipe( - ops.map(blocks.Resample(self.source.sample_rate, sample_rate, self.pipeline.config.device)) + ops.map( + blocks.Resample( + self.source.sample_rate, + sample_rate, + self.pipeline.config.device, + ) + ) ) # Form batches @@ -145,7 +156,9 @@ def _close_chronometer(self): self._chrono.stop(do_count=False) self._chrono.report() - def attach_hooks(self, *hooks: Callable[[Tuple[Annotation, SlidingWindowFeature]], None]): + def attach_hooks( + self, *hooks: Callable[[Tuple[Annotation, SlidingWindowFeature]], None] + ): """Attach hooks to the pipeline. Parameters @@ -251,6 +264,7 @@ class Benchmark: The performance between this two modes does not differ. Defaults to 32. """ + def __init__( self, speech_path: Union[Text, Path], @@ -430,6 +444,7 @@ class Parallelize: Number of parallel workers. Defaults to 0 (no parallelism). """ + def __init__( self, benchmark: Benchmark, @@ -466,12 +481,14 @@ def run_single_job( Pipeline prediction for the given file. """ # The process ID inside the pool determines the position of the progress bar - idx_process = int(current_process().name.split('-')[1]) - 1 + idx_process = int(current_process().name.split("-")[1]) - 1 # TODO share models across processes # Instantiate a pipeline with the config pipeline = pipeline_class(config) # Create the progress bar for this job - progress = TQDMProgressBar(description, leave=False, position=idx_process, do_close=True) + progress = TQDMProgressBar( + description, leave=False, position=idx_process, do_close=True + ) # Run the pipeline return self.benchmark.run_single(pipeline, filepath, progress) @@ -507,12 +524,14 @@ def __call__( num_audio_files = len(audio_file_paths) # Workaround for multiprocessing with GPU - torch.multiprocessing.set_start_method('spawn') + torch.multiprocessing.set_start_method("spawn") # For Windows support freeze_support() # Create the pool of workers using a lock for parallel tqdm usage - pool = Pool(processes=self.num_workers, initargs=(RLock(),), initializer=tqdm.set_lock) + pool = Pool( + processes=self.num_workers, initargs=(RLock(),), initializer=tqdm.set_lock + ) # Determine the arguments for each job arg_list = [ ( diff --git a/src/diart/models.py b/src/diart/models.py index 42056c44..5577a097 100644 --- a/src/diart/models.py +++ b/src/diart/models.py @@ -6,6 +6,7 @@ try: import pyannote.audio.pipelines.utils as pyannote_loader + _has_pyannote = True except ImportError: _has_pyannote = False @@ -48,8 +49,11 @@ class SegmentationModel(LazyModel): """ Minimal interface for a segmentation model. """ + @staticmethod - def from_pyannote(model, use_hf_token: Union[Text, bool, None] = True) -> 'SegmentationModel': + def from_pyannote( + model, use_hf_token: Union[Text, bool, None] = True + ) -> "SegmentationModel": """ Returns a `SegmentationModel` wrapping a pyannote model. @@ -115,8 +119,11 @@ def forward(self, waveform: torch.Tensor) -> torch.Tensor: class EmbeddingModel(LazyModel): """Minimal interface for an embedding model.""" + @staticmethod - def from_pyannote(model, use_hf_token: Union[Text, bool, None] = True) -> 'EmbeddingModel': + def from_pyannote( + model, use_hf_token: Union[Text, bool, None] = True + ) -> "EmbeddingModel": """ Returns an `EmbeddingModel` wrapping a pyannote model. 
@@ -138,9 +145,7 @@ def from_pyannote(model, use_hf_token: Union[Text, bool, None] = True) -> 'Embed @abstractmethod def forward( - self, - waveform: torch.Tensor, - weights: Optional[torch.Tensor] = None + self, waveform: torch.Tensor, weights: Optional[torch.Tensor] = None ) -> torch.Tensor: """ Forward pass of an embedding model with optional weights. diff --git a/src/diart/operators.py b/src/diart/operators.py index 6d73fc9d..7ce13285 100644 --- a/src/diart/operators.py +++ b/src/diart/operators.py @@ -25,20 +25,24 @@ def initial(): def has_samples(num_samples: int): def call_fn(state) -> bool: return state.chunk is not None and state.chunk.shape[1] == num_samples + return call_fn @staticmethod def to_sliding_window(sample_rate: int): def call_fn(state) -> SlidingWindowFeature: - resolution = SlidingWindow(start=state.start_time, duration=1. / sample_rate, step=1. / sample_rate) + resolution = SlidingWindow( + start=state.start_time, + duration=1.0 / sample_rate, + step=1.0 / sample_rate, + ) return SlidingWindowFeature(state.chunk.T, resolution) + return call_fn def rearrange_audio_stream( - duration: float = 5, - step: float = 0.5, - sample_rate: int = 16000 + duration: float = 5, step: float = 0.5, sample_rate: int = 16000 ) -> Operator: chunk_samples = int(round(sample_rate * duration)) step_samples = int(round(sample_rate * step)) @@ -49,11 +53,17 @@ def rearrange_audio_stream( def accumulate(state: AudioBufferState, value: np.ndarray): # State contains the last emitted chunk, the current step buffer and the last starting time if value.ndim != 2 or value.shape[0] != 1: - raise ValueError(f"Waveform must have shape (1, samples) but {value.shape} was found") + raise ValueError( + f"Waveform must have shape (1, samples) but {value.shape} was found" + ) start_time = state.start_time # Add new samples to the buffer - buffer = value if state.buffer is None else np.concatenate([state.buffer, value], axis=1) + buffer = ( + value + if state.buffer is None + else np.concatenate([state.buffer, value], axis=1) + ) # Check for buffer overflow if buffer.shape[1] >= step_samples: @@ -86,7 +96,7 @@ def accumulate(state: AudioBufferState, value: np.ndarray): ops.filter(AudioBufferState.has_samples(chunk_samples)), ops.filter(lambda state: state.changed), # Transform state into a SlidingWindowFeature containing the new chunk - ops.map(AudioBufferState.to_sliding_window(sample_rate)) + ops.map(AudioBufferState.to_sliding_window(sample_rate)), ) @@ -96,6 +106,7 @@ def accumulate(state: List[Any], value: Any) -> List[Any]: if len(new_state) > n: return new_state[1:] return new_state + return rx.pipe(ops.scan(accumulate, [])) @@ -117,17 +128,19 @@ class OutputAccumulationState: next_sample: Optional[int] @staticmethod - def initial() -> 'OutputAccumulationState': + def initial() -> "OutputAccumulationState": return OutputAccumulationState(None, None, 0, 0) @property def cropped_waveform(self) -> SlidingWindowFeature: return SlidingWindowFeature( - self.waveform[:self.next_sample], + self.waveform[: self.next_sample], self.waveform.sliding_window, ) - def to_tuple(self) -> Tuple[Optional[Annotation], Optional[SlidingWindowFeature], float]: + def to_tuple( + self, + ) -> Tuple[Optional[Annotation], Optional[SlidingWindowFeature], float]: return self.annotation, self.cropped_waveform, self.real_time @@ -153,9 +166,10 @@ def accumulate_output( ------- A reactive x operator implementing this behavior. 
""" + def accumulate( state: OutputAccumulationState, - value: Tuple[Annotation, Optional[SlidingWindowFeature]] + value: Tuple[Annotation, Optional[SlidingWindowFeature]], ) -> OutputAccumulationState: value = PredictionWithAudio(*value) annotation, waveform = None, None @@ -187,7 +201,7 @@ def accumulate( (state.waveform.data, np.zeros_like(state.waveform.data)), axis=0 ) # Copy chunk into buffer - waveform[state.next_sample:new_next_sample] = value.waveform.data + waveform[state.next_sample : new_next_sample] = value.waveform.data waveform = SlidingWindowFeature(waveform, sw_holder.waveform.sliding_window) return OutputAccumulationState(annotation, waveform, real_time, new_next_sample) @@ -234,14 +248,14 @@ def buffer_output( def accumulate( state: OutputAccumulationState, - value: Tuple[Annotation, Optional[SlidingWindowFeature]] + value: Tuple[Annotation, Optional[SlidingWindowFeature]], ) -> OutputAccumulationState: value = PredictionWithAudio(*value) annotation, waveform = None, None # Determine the real time of the stream and the start time of the buffer real_time = duration if state.annotation is None else state.real_time + step - start_time = max(0., real_time - latency - duration) + start_time = max(0.0, real_time - latency - duration) # Update annotation and constrain its bounds to the buffer if state.annotation is None: @@ -267,7 +281,7 @@ def accumulate( elif state.next_sample <= num_samples: # The buffer isn't full, copy into next free buffer chunk waveform = state.waveform.data - waveform[state.next_sample:new_next_sample] = value.waveform.data + waveform[state.next_sample : new_next_sample] = value.waveform.data else: # The buffer is full, shift values to the left and copy into last buffer chunk waveform = np.roll(state.waveform.data, -num_step_samples, axis=0) @@ -277,7 +291,9 @@ def accumulate( waveform[-num_step_samples:] = value.waveform.data[:num_step_samples] # Wrap waveform in a sliding window feature to include timestamps - window = SlidingWindow(start=start_time, duration=resolution, step=resolution) + window = SlidingWindow( + start=start_time, duration=resolution, step=resolution + ) waveform = SlidingWindowFeature(waveform, window) return OutputAccumulationState(annotation, waveform, real_time, new_next_sample) diff --git a/src/diart/optim.py b/src/diart/optim.py index 86492627..ca61d744 100644 --- a/src/diart/optim.py +++ b/src/diart/optim.py @@ -54,8 +54,10 @@ def __init__( # Make sure hyper-parameters exist in the configuration class given possible_hparams = vars(self.base_config) for param in self.hparams: - msg = f"Hyper-parameter {param.name} not found " \ - f"in configuration {self.base_config.__class__.__name__}" + msg = ( + f"Hyper-parameter {param.name} not found " + f"in configuration {self.base_config.__class__.__name__}" + ) assert param.name in possible_hparams, msg self._progress: Optional[tqdm] = None @@ -129,8 +131,11 @@ def __call__(self, num_iter: int, show_progress: bool = True): self._progress.set_description(f"Trial {last_trial + 1}") # Start with base config hyper-parameters if config was given if self.do_kickstart_hparams: - self.study.enqueue_trial({ - param.name: getattr(self.base_config, param.name) - for param in self.hparams - }, skip_if_exists=True) + self.study.enqueue_trial( + { + param.name: getattr(self.base_config, param.name) + for param in self.hparams + }, + skip_if_exists=True, + ) self.study.optimize(self.objective, num_iter, callbacks=[self._callback]) diff --git a/src/diart/progress.py b/src/diart/progress.py index 
ca62e8bb..d29080f9 100644 --- a/src/diart/progress.py +++ b/src/diart/progress.py @@ -8,7 +8,13 @@ class ProgressBar(ABC): @abstractmethod - def create(self, total: int, description: Optional[Text] = None, unit: Text = "it", **kwargs): + def create( + self, + total: int, + description: Optional[Text] = None, + unit: Text = "it", + **kwargs, + ): pass @abstractmethod @@ -75,7 +81,13 @@ def initial_description(self) -> Optional[Text]: return f"[{self.color}]{self.description}" return self.description - def create(self, total: int, description: Optional[Text] = None, unit: Text = "it", **kwargs): + def create( + self, + total: int, + description: Optional[Text] = None, + unit: Text = "it", + **kwargs, + ): if self.task_id is None: self.task_id = self.bar.add_task( self.resolve_description(f"[{self.color}]{description}"), @@ -83,7 +95,7 @@ def create(self, total: int, description: Optional[Text] = None, unit: Text = "i total=total, completed=0, visible=True, - **kwargs + **kwargs, ) def start(self): @@ -112,7 +124,7 @@ def __init__( description: Optional[Text] = None, leave: bool = True, position: Optional[int] = None, - do_close: bool = True + do_close: bool = True, ): self.description = description self.leave = leave @@ -128,7 +140,13 @@ def default_description(self) -> Text: def initial_description(self) -> Optional[Text]: return self.description - def create(self, total: int, description: Optional[Text] = None, unit: Optional[Text] = "it", **kwargs): + def create( + self, + total: int, + description: Optional[Text] = None, + unit: Optional[Text] = "it", + **kwargs, + ): if self.pbar is None: self.pbar = tqdm( desc=self.resolve_description(description), @@ -136,7 +154,7 @@ def create(self, total: int, description: Optional[Text] = None, unit: Optional[ unit=unit, leave=self.leave, position=self.position, - **kwargs + **kwargs, ) def start(self): diff --git a/src/diart/sinks.py b/src/diart/sinks.py index 63c170d0..ed4e2ea0 100644 --- a/src/diart/sinks.py +++ b/src/diart/sinks.py @@ -41,14 +41,14 @@ def patch(self): if annotations: annotation = annotations[0] annotation.uri = self.uri - with open(self.path, 'w') as file: + with open(self.path, "w") as file: annotation.support(self.patch_collar).write_rttm(file) def on_next(self, value: Union[Tuple, Annotation]): prediction = _extract_prediction(value) # Write prediction in RTTM format prediction.uri = self.uri - with open(self.path, 'a') as file: + with open(self.path, "a") as file: prediction.write_rttm(file) def on_error(self, error: Exception): @@ -121,10 +121,12 @@ def _init_num_axs(self): def _init_figure(self): self._init_num_axs() - self.figure, self.axs = plt.subplots(self.num_axs, 1, figsize=(10, 2 * self.num_axs)) + self.figure, self.axs = plt.subplots( + self.num_axs, 1, figsize=(10, 2 * self.num_axs) + ) if self.num_axs == 1: self.axs = [self.axs] - self.figure.canvas.mpl_connect('close_event', self._on_window_closed) + self.figure.canvas.mpl_connect("close_event", self._on_window_closed) def _clear_axs(self): for i in range(self.num_axs): @@ -134,13 +136,10 @@ def get_plot_bounds(self, real_time: float) -> Segment: start_time = 0 end_time = real_time - self.latency if self.visualization == "slide": - start_time = max(0., end_time - self.window_duration) + start_time = max(0.0, end_time - self.window_duration) return Segment(start_time, end_time) - def on_next( - self, - values: Tuple[Annotation, SlidingWindowFeature, float] - ): + def on_next(self, values: Tuple[Annotation, SlidingWindowFeature, float]): if self.window_closed: 
raise WindowClosedException diff --git a/src/diart/sources.py b/src/diart/sources.py index 5ae6c0eb..82051b2e 100644 --- a/src/diart/sources.py +++ b/src/diart/sources.py @@ -25,6 +25,7 @@ class AudioSource(ABC): sample_rate: int Sample rate of the audio source. """ + def __init__(self, uri: Text, sample_rate: int): self.uri = uri self.sample_rate = sample_rate @@ -62,6 +63,7 @@ class FileAudioSource(AudioSource): Duration of each emitted chunk in seconds. Defaults to 0.5 seconds. """ + def __init__( self, file: FilePath, @@ -108,9 +110,13 @@ def read(self): # Add last incomplete chunk with padding if num_samples % self.block_size != 0: - last_chunk = waveform[:, chunks.shape[0] * self.block_size:].unsqueeze(0).numpy() + last_chunk = ( + waveform[:, chunks.shape[0] * self.block_size :].unsqueeze(0).numpy() + ) diff_samples = self.block_size - last_chunk.shape[-1] - last_chunk = np.concatenate([last_chunk, np.zeros((1, 1, diff_samples))], axis=-1) + last_chunk = np.concatenate( + [last_chunk, np.zeros((1, 1, diff_samples))], axis=-1 + ) chunks = np.vstack([chunks, last_chunk]) # Stream blocks @@ -215,13 +221,14 @@ class WebSocketAudioSource(AudioSource): Path to a certificate if using SSL. Defaults to no certificate. """ + def __init__( self, sample_rate: int, host: Text = "127.0.0.1", port: int = 7007, key: Optional[Union[Text, Path]] = None, - certificate: Optional[Union[Text, Path]] = None + certificate: Optional[Union[Text, Path]] = None, ): # FIXME sample_rate is not being used, this can be confusing and lead to incompatibilities. # I would prefer the client to send a JSON with data and sample rate, then resample if needed @@ -300,3 +307,16 @@ def read(self): def close(self): self.is_closed = True + + +class AppleDeviceAudioSource(TorchStreamAudioSource): + def __init__( + self, + sample_rate: int, + device: str = "0:0", + stream_index: int = 0, + block_duration: float = 0.5, + ): + uri = f"apple_input_device:{device}" + streamer = StreamReader(device, format="avfoundation") + super().__init__(uri, sample_rate, streamer, stream_index, block_duration) diff --git a/src/diart/utils.py b/src/diart/utils.py index e825ef29..f0eb4751 100644 --- a/src/diart/utils.py +++ b/src/diart/utils.py @@ -102,6 +102,7 @@ def apply(feature: SlidingWindowFeature): notebook.plot_feature(feature) plt.tight_layout() plt.show() + return apply @@ -116,4 +117,5 @@ def apply(annotation: Annotation): notebook.plot_annotation(annotation) plt.tight_layout() plt.show() + return apply From 45f8ad9cb227c96fe9c184542e2ffc0ad1aff4ec Mon Sep 17 00:00:00 2001 From: Juan Coria Date: Wed, 11 Oct 2023 17:15:00 +0200 Subject: [PATCH 11/14] Catch keyboard interrupt in diart.stream (#183) --- src/diart/console/stream.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/diart/console/stream.py b/src/diart/console/stream.py index f7d96360..87f3a0a1 100644 --- a/src/diart/console/stream.py +++ b/src/diart/console/stream.py @@ -122,7 +122,10 @@ def run(): inference.attach_observers( RTTMWriter(audio_source.uri, args.output / f"{audio_source.uri}.rttm") ) - inference() + try: + inference() + except KeyboardInterrupt: + pass if __name__ == "__main__": From 0113ab2d56398224157615a70325802602a3ec91 Mon Sep 17 00:00:00 2001 From: Juan Coria Date: Thu, 19 Oct 2023 18:20:57 +0200 Subject: [PATCH 12/14] Remove `PipelineConfig.from_dict()` (#189) * Unify hparam naming. Clean some typing annotations * Remove config.from_dict(). 
Add --duration argument to CLI * Update README.md accordingly --- README.md | 2 +- src/diart/blocks/base.py | 10 +--- src/diart/blocks/diarization.py | 83 +++++++++------------------------ src/diart/blocks/vad.py | 60 +++++++----------------- src/diart/console/benchmark.py | 24 ++++++++-- src/diart/console/client.py | 4 +- src/diart/console/serve.py | 24 ++++++++-- src/diart/console/stream.py | 23 +++++++-- src/diart/console/tune.py | 26 +++++++++-- src/diart/utils.py | 8 +--- 10 files changed, 124 insertions(+), 140 deletions(-) diff --git a/README.md b/README.md index caef5045..3d31bc6d 100644 --- a/README.md +++ b/README.md @@ -340,7 +340,7 @@ To obtain the best results, make sure to use the following hyper-parameters: `diart.benchmark` and `diart.inference.Benchmark` can run, evaluate and measure the real-time latency of the pipeline. For instance, for a DIHARD III configuration: ```shell -diart.benchmark /wav/dir --reference /rttm/dir --tau=0.555 --rho=0.422 --delta=1.517 --segmentation pyannote/segmentation@Interspeech2021 +diart.benchmark /wav/dir --reference /rttm/dir --tau-active=0.555 --rho-update=0.422 --delta-new=1.517 --segmentation pyannote/segmentation@Interspeech2021 ``` or using the inference API: diff --git a/src/diart/blocks/base.py b/src/diart/blocks/base.py index 6536a3f7..f6ca3a33 100644 --- a/src/diart/blocks/base.py +++ b/src/diart/blocks/base.py @@ -1,8 +1,7 @@ -from typing import Any, Tuple, Sequence, Text -from dataclasses import dataclass from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Tuple, Sequence, Text -import numpy as np from pyannote.core import SlidingWindowFeature from pyannote.metrics.base import BaseMetric @@ -53,11 +52,6 @@ def latency(self) -> float: def sample_rate(self) -> int: pass - @staticmethod - @abstractmethod - def from_dict(data: Any) -> "PipelineConfig": - pass - def get_file_padding(self, filepath: FilePath) -> Tuple[float, float]: file_duration = AudioLoader(self.sample_rate, mono=True).get_duration(filepath) right = utils.get_padding_right(self.latency, self.step) diff --git a/src/diart/blocks/diarization.py b/src/diart/blocks/diarization.py index 3cf4e333..fab83c36 100644 --- a/src/diart/blocks/diarization.py +++ b/src/diart/blocks/diarization.py @@ -1,4 +1,6 @@ -from typing import Optional, Tuple, Sequence, Union, Any +from __future__ import annotations + +from typing import Sequence import numpy as np import torch @@ -14,40 +16,37 @@ from .segmentation import SpeakerSegmentation from .utils import Binarize from .. import models as m -from .. 
import utils class SpeakerDiarizationConfig(base.PipelineConfig): def __init__( self, - segmentation: Optional[m.SegmentationModel] = None, - embedding: Optional[m.EmbeddingModel] = None, - duration: Optional[float] = None, + segmentation: m.SegmentationModel | None = None, + embedding: m.EmbeddingModel | None = None, + duration: float | None = None, step: float = 0.5, - latency: Optional[Union[float, Literal["max", "min"]]] = None, + latency: float | Literal["max", "min"] | None = None, tau_active: float = 0.6, rho_update: float = 0.3, delta_new: float = 1, gamma: float = 3, beta: float = 10, max_speakers: int = 20, - device: Optional[torch.device] = None, + device: torch.device | None = None, **kwargs, ): # Default segmentation model is pyannote/segmentation - self.segmentation = segmentation - if self.segmentation is None: - self.segmentation = m.SegmentationModel.from_pyannote( - "pyannote/segmentation" - ) - - self._duration = duration - self._sample_rate: Optional[int] = None + self.segmentation = segmentation or m.SegmentationModel.from_pyannote( + "pyannote/segmentation" + ) # Default embedding model is pyannote/embedding - self.embedding = embedding - if self.embedding is None: - self.embedding = m.EmbeddingModel.from_pyannote("pyannote/embedding") + self.embedding = embedding or m.EmbeddingModel.from_pyannote( + "pyannote/embedding" + ) + + self._duration = duration + self._sample_rate: int | None = None # Latency defaults to the step duration self._step = step @@ -64,48 +63,8 @@ def __init__( self.beta = beta self.max_speakers = max_speakers - self.device = device - if self.device is None: - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - @staticmethod - def from_dict(data: Any) -> "SpeakerDiarizationConfig": - # Check for explicit device, otherwise check for 'cpu' bool, otherwise pass None - device = utils.get(data, "device", None) - if device is None: - device = torch.device("cpu") if utils.get(data, "cpu", False) else None - - # Instantiate models - hf_token = utils.parse_hf_token_arg(utils.get(data, "hf_token", True)) - segmentation = utils.get(data, "segmentation", "pyannote/segmentation") - segmentation = m.SegmentationModel.from_pyannote(segmentation, hf_token) - embedding = utils.get(data, "embedding", "pyannote/embedding") - embedding = m.EmbeddingModel.from_pyannote(embedding, hf_token) - - # Hyper-parameters and their aliases - tau = utils.get(data, "tau_active", None) - if tau is None: - tau = utils.get(data, "tau", 0.6) - rho = utils.get(data, "rho_update", None) - if rho is None: - rho = utils.get(data, "rho", 0.3) - delta = utils.get(data, "delta_new", None) - if delta is None: - delta = utils.get(data, "delta", 1) - - return SpeakerDiarizationConfig( - segmentation=segmentation, - embedding=embedding, - duration=utils.get(data, "duration", None), - step=utils.get(data, "step", 0.5), - latency=utils.get(data, "latency", None), - tau_active=tau, - rho_update=rho, - delta_new=delta, - gamma=utils.get(data, "gamma", 3), - beta=utils.get(data, "beta", 10), - max_speakers=utils.get(data, "max_speakers", 20), - device=device, + self.device = device or torch.device( + "cuda" if torch.cuda.is_available() else "cpu" ) @property @@ -132,7 +91,7 @@ def sample_rate(self) -> int: class SpeakerDiarization(base.Pipeline): - def __init__(self, config: Optional[SpeakerDiarizationConfig] = None): + def __init__(self, config: SpeakerDiarizationConfig | None = None): self._config = SpeakerDiarizationConfig() if config is None else config msg = f"Latency 
should be in the range [{self._config.step}, {self._config.duration}]" @@ -200,7 +159,7 @@ def reset(self): def __call__( self, waveforms: Sequence[SlidingWindowFeature] - ) -> Sequence[Tuple[Annotation, SlidingWindowFeature]]: + ) -> Sequence[tuple[Annotation, SlidingWindowFeature]]: batch_size = len(waveforms) msg = "Pipeline expected at least 1 input" assert batch_size >= 1, msg diff --git a/src/diart/blocks/vad.py b/src/diart/blocks/vad.py index 04fe5608..0edd3e0b 100644 --- a/src/diart/blocks/vad.py +++ b/src/diart/blocks/vad.py @@ -1,4 +1,6 @@ -from typing import Any, Optional, Union, Sequence, Tuple +from __future__ import annotations + +from typing import Sequence import numpy as np import torch @@ -13,8 +15,8 @@ from pyannote.metrics.detection import DetectionErrorRate from typing_extensions import Literal -from .aggregation import DelayedAggregation from . import base +from .aggregation import DelayedAggregation from .segmentation import SpeakerSegmentation from .utils import Binarize from .. import models as m @@ -24,24 +26,22 @@ class VoiceActivityDetectionConfig(base.PipelineConfig): def __init__( self, - segmentation: Optional[m.SegmentationModel] = None, - duration: Optional[float] = None, + segmentation: m.SegmentationModel | None = None, + duration: float | None = None, step: float = 0.5, - latency: Optional[Union[float, Literal["max", "min"]]] = None, + latency: float | Literal["max", "min"] | None = None, tau_active: float = 0.6, - device: Optional[torch.device] = None, + device: torch.device | None = None, **kwargs, ): # Default segmentation model is pyannote/segmentation - self.segmentation = segmentation - if self.segmentation is None: - self.segmentation = m.SegmentationModel.from_pyannote( - "pyannote/segmentation" - ) + self.segmentation = segmentation or m.SegmentationModel.from_pyannote( + "pyannote/segmentation" + ) self._duration = duration self._step = step - self._sample_rate: Optional[int] = None + self._sample_rate: int | None = None # Latency defaults to the step duration self._latency = latency @@ -51,9 +51,9 @@ def __init__( self._latency = self._duration self.tau_active = tau_active - self.device = device - if self.device is None: - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.device = device or torch.device( + "cuda" if torch.cuda.is_available() else "cpu" + ) @property def duration(self) -> float: @@ -77,35 +77,9 @@ def sample_rate(self) -> int: self._sample_rate = self.segmentation.sample_rate return self._sample_rate - @staticmethod - def from_dict(data: Any) -> "VoiceActivityDetectionConfig": - # Check for explicit device, otherwise check for 'cpu' bool, otherwise pass None - device = utils.get(data, "device", None) - if device is None: - device = torch.device("cpu") if utils.get(data, "cpu", False) else None - - # Instantiate segmentation model - hf_token = utils.parse_hf_token_arg(utils.get(data, "hf_token", True)) - segmentation = utils.get(data, "segmentation", "pyannote/segmentation") - segmentation = m.SegmentationModel.from_pyannote(segmentation, hf_token) - - # Tau active and its alias - tau = utils.get(data, "tau_active", None) - if tau is None: - tau = utils.get(data, "tau", 0.6) - - return VoiceActivityDetectionConfig( - segmentation=segmentation, - duration=utils.get(data, "duration", None), - step=utils.get(data, "step", 0.5), - latency=utils.get(data, "latency", None), - tau_active=tau, - device=device, - ) - class VoiceActivityDetection(base.Pipeline): - def __init__(self, config: 
Optional[VoiceActivityDetectionConfig] = None): + def __init__(self, config: VoiceActivityDetectionConfig | None = None): self._config = VoiceActivityDetectionConfig() if config is None else config msg = f"Latency should be in the range [{self._config.step}, {self._config.duration}]" @@ -158,7 +132,7 @@ def set_timestamp_shift(self, shift: float): def __call__( self, waveforms: Sequence[SlidingWindowFeature], - ) -> Sequence[Tuple[Annotation, SlidingWindowFeature]]: + ) -> Sequence[tuple[Annotation, SlidingWindowFeature]]: batch_size = len(waveforms) msg = "Pipeline expected at least 1 input" assert batch_size >= 1, msg diff --git a/src/diart/console/benchmark.py b/src/diart/console/benchmark.py index 70b4c3d9..b5a296d1 100644 --- a/src/diart/console/benchmark.py +++ b/src/diart/console/benchmark.py @@ -2,7 +2,10 @@ from pathlib import Path import pandas as pd +import torch + from diart import argdoc +from diart import models as m from diart import utils from diart.inference import Benchmark, Parallelize @@ -37,6 +40,11 @@ def run(): type=Path, help="Optional. Directory with RTTM files CONVERSATION.rttm. Names must match audio files", ) + parser.add_argument( + "--duration", + type=float, + help=f"{argdoc.DURATION}. Defaults to training segmentation duration", + ) parser.add_argument( "--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5" ) @@ -44,13 +52,13 @@ def run(): "--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5" ) parser.add_argument( - "--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5" + "--tau-active", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5" ) parser.add_argument( - "--rho", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3" + "--rho-update", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3" ) parser.add_argument( - "--delta", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1" + "--delta-new", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1" ) parser.add_argument( "--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. 
Defaults to 3" @@ -93,6 +101,14 @@ def run(): ) args = parser.parse_args() + # Resolve device + args.device = torch.device("cpu") if args.cpu else None + + # Resolve models + hf_token = utils.parse_hf_token_arg(args.hf_token) + args.segmentation = m.SegmentationModel.from_pyannote(args.segmentation, hf_token) + args.embedding = m.EmbeddingModel.from_pyannote(args.embedding, hf_token) + pipeline_class = utils.get_pipeline_class(args.pipeline) benchmark = Benchmark( @@ -104,7 +120,7 @@ def run(): batch_size=args.batch_size, ) - config = pipeline_class.get_config_class().from_dict(vars(args)) + config = pipeline_class.get_config_class()(**vars(args)) if args.num_workers > 0: benchmark = Parallelize(benchmark, args.num_workers) diff --git a/src/diart/console/client.py b/src/diart/console/client.py index b656298a..b3de36db 100644 --- a/src/diart/console/client.py +++ b/src/diart/console/client.py @@ -3,12 +3,12 @@ from threading import Thread from typing import Text, Optional -import numpy as np import rx.operators as ops +from websocket import WebSocket + from diart import argdoc from diart import sources as src from diart import utils -from websocket import WebSocket def send_audio(ws: WebSocket, source: Text, step: float, sample_rate: int): diff --git a/src/diart/console/serve.py b/src/diart/console/serve.py index d8c059c3..bc002e42 100644 --- a/src/diart/console/serve.py +++ b/src/diart/console/serve.py @@ -1,7 +1,10 @@ import argparse from pathlib import Path +import torch + from diart import argdoc +from diart import models as m from diart import sources as src from diart import utils from diart.inference import StreamingInference @@ -30,6 +33,11 @@ def run(): type=str, help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding", ) + parser.add_argument( + "--duration", + type=float, + help=f"{argdoc.DURATION}. Defaults to training segmentation duration", + ) parser.add_argument( "--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5" ) @@ -37,13 +45,13 @@ def run(): "--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5" ) parser.add_argument( - "--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5" + "--tau-active", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5" ) parser.add_argument( - "--rho", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3" + "--rho-update", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3" ) parser.add_argument( - "--delta", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1" + "--delta-new", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1" ) parser.add_argument( "--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. 
Defaults to 3" @@ -74,9 +82,17 @@ def run(): ) args = parser.parse_args() + # Resolve device + args.device = torch.device("cpu") if args.cpu else None + + # Resolve models + hf_token = utils.parse_hf_token_arg(args.hf_token) + args.segmentation = m.SegmentationModel.from_pyannote(args.segmentation, hf_token) + args.embedding = m.EmbeddingModel.from_pyannote(args.embedding, hf_token) + # Resolve pipeline pipeline_class = utils.get_pipeline_class(args.pipeline) - config = pipeline_class.get_config_class().from_dict(vars(args)) + config = pipeline_class.get_config_class()(**vars(args)) pipeline = pipeline_class(config) # Create websocket audio source diff --git a/src/diart/console/stream.py b/src/diart/console/stream.py index 87f3a0a1..713f3e99 100644 --- a/src/diart/console/stream.py +++ b/src/diart/console/stream.py @@ -1,7 +1,10 @@ import argparse from pathlib import Path +import torch + from diart import argdoc +from diart import models as m from diart import sources as src from diart import utils from diart.inference import StreamingInference @@ -34,7 +37,9 @@ def run(): help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding", ) parser.add_argument( - "--duration", type=float, help=f"{argdoc.DURATION}. Defaults to training segmentation duration" + "--duration", + type=float, + help=f"{argdoc.DURATION}. Defaults to training segmentation duration", ) parser.add_argument( "--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5" @@ -43,13 +48,13 @@ def run(): "--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5" ) parser.add_argument( - "--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5" + "--tau-active", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5" ) parser.add_argument( - "--rho", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3" + "--rho-update", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3" ) parser.add_argument( - "--delta", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1" + "--delta-new", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1" ) parser.add_argument( "--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. Defaults to 3" @@ -88,9 +93,17 @@ def run(): ) args = parser.parse_args() + # Resolve device + args.device = torch.device("cpu") if args.cpu else None + + # Resolve models + hf_token = utils.parse_hf_token_arg(args.hf_token) + args.segmentation = m.SegmentationModel.from_pyannote(args.segmentation, hf_token) + args.embedding = m.EmbeddingModel.from_pyannote(args.embedding, hf_token) + # Resolve pipeline pipeline_class = utils.get_pipeline_class(args.pipeline) - config = pipeline_class.get_config_class().from_dict(vars(args)) + config = pipeline_class.get_config_class()(**vars(args)) pipeline = pipeline_class(config) # Manage audio source diff --git a/src/diart/console/tune.py b/src/diart/console/tune.py index 4c969efa..ec243348 100644 --- a/src/diart/console/tune.py +++ b/src/diart/console/tune.py @@ -2,11 +2,14 @@ from pathlib import Path import optuna +import torch +from optuna.samplers import TPESampler + from diart import argdoc +from diart import models as m from diart import utils from diart.blocks.base import HyperParameter from diart.optim import Optimizer -from optuna.samplers import TPESampler def run(): @@ -40,6 +43,11 @@ def run(): type=str, help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding", ) + parser.add_argument( + "--duration", + type=float, + help=f"{argdoc.DURATION}. 
Defaults to training segmentation duration", + ) parser.add_argument( "--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5" ) @@ -47,13 +55,13 @@ def run(): "--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5" ) parser.add_argument( - "--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5" + "--tau-active", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5" ) parser.add_argument( - "--rho", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3" + "--rho-update", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3" ) parser.add_argument( - "--delta", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1" + "--delta-new", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1" ) parser.add_argument( "--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. Defaults to 3" @@ -102,11 +110,19 @@ def run(): ) args = parser.parse_args() + # Resolve device + args.device = torch.device("cpu") if args.cpu else None + + # Resolve models + hf_token = utils.parse_hf_token_arg(args.hf_token) + args.segmentation = m.SegmentationModel.from_pyannote(args.segmentation, hf_token) + args.embedding = m.EmbeddingModel.from_pyannote(args.embedding, hf_token) + # Retrieve pipeline class pipeline_class = utils.get_pipeline_class(args.pipeline) # Create the base configuration for each trial - base_config = pipeline_class.get_config_class().from_dict(vars(args)) + base_config = pipeline_class.get_config_class()(**vars(args)) # Create hyper-parameters to optimize possible_hparams = pipeline_class.hyper_parameters() diff --git a/src/diart/utils.py b/src/diart/utils.py index f0eb4751..ca27d022 100644 --- a/src/diart/utils.py +++ b/src/diart/utils.py @@ -1,13 +1,13 @@ import base64 import time -from typing import Optional, Text, Union, Any, Dict +from typing import Optional, Text, Union import matplotlib.pyplot as plt import numpy as np from pyannote.core import Annotation, Segment, SlidingWindowFeature, notebook -from .progress import ProgressBar from . 
import blocks +from .progress import ProgressBar class Chronometer: @@ -53,10 +53,6 @@ def parse_hf_token_arg(hf_token: Union[bool, Text]) -> Union[bool, Text]: return hf_token -def get(data: Dict[Text, Any], key: Text, default: Any) -> Any: - return data[key] if key in data else default - - def encode_audio(waveform: np.ndarray) -> Text: data = waveform.astype(np.float32).tobytes() return base64.b64encode(data).decode("utf-8") From c90a0dca032016720c97115eb89a1131d93db5b5 Mon Sep 17 00:00:00 2001 From: juanmc2005 Date: Thu, 19 Oct 2023 18:26:17 +0200 Subject: [PATCH 13/14] Bump up version to 0.8.0 --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index c2e536d8..f38a612e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name=diart -version=0.7.0 +version=0.8.0 author=Juan Manuel Coria description=Streaming speaker diarization in real-time long_description=file: README.md From b26e60c39ca4abf3406d04300d518858ece900ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Herv=C3=A9=20BREDIN?= Date: Thu, 26 Oct 2023 14:09:23 +0200 Subject: [PATCH 14/14] Fix link to reproducibility section (#191) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 3d31bc6d..f9d87b91 100644 --- a/README.md +++ b/README.md @@ -119,7 +119,7 @@ inference.attach_observers(RTTMWriter(mic.uri, "/output/file.rttm")) prediction = inference() ``` -For inference and evaluation on a dataset we recommend to use `Benchmark` (see notes on [reproducibility](#reproducibility)). +For inference and evaluation on a dataset we recommend to use `Benchmark` (see notes on [reproducibility](#-reproducibility)). ## 🤖 Add your model
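
A minimal Python sketch of the configuration flow after `PipelineConfig.from_dict()` is removed: hyper-parameters are passed as keyword arguments straight to the config class, mirroring what the console scripts now do with `pipeline_class.get_config_class()(**vars(args))`. The snippet assumes the constructor signatures visible in the diffs above (`SpeakerDiarizationConfig`, `VoiceActivityDetectionConfig`, `MicrophoneAudioSource`, `StreamingInference`, `RTTMWriter`); it is an illustrative sketch, not part of the patch, and argument names may differ slightly from the released package.

```python
# Illustrative only -- signatures follow the diffs in this patch series
# and may differ from the released package.
import diart
from diart import sources as src
from diart.inference import StreamingInference
from diart.sinks import RTTMWriter

# Keyword arguments replace the old from_dict() construction.
config = diart.SpeakerDiarizationConfig(
    step=0.5,
    latency=0.5,
    tau_active=0.555,   # DIHARD III values quoted in the README example
    rho_update=0.422,
    delta_new=1.517,
)
pipeline = diart.SpeakerDiarization(config)

# Microphone source and streaming inference, as in the README snippet.
mic = src.MicrophoneAudioSource(config.step, None)  # None = default input device
inference = StreamingInference(pipeline, mic, do_plot=False)
inference.attach_observers(RTTMWriter(mic.uri, "/output/file.rttm"))
prediction = inference()

# The same pattern applies to the streaming VAD pipeline added in this series.
vad_config = diart.VoiceActivityDetectionConfig(step=0.5, latency=0.5, tau_active=0.6)
vad = diart.VoiceActivityDetection(vad_config)
```

The command-line equivalent uses the renamed flags introduced here, e.g. `diart.stream microphone --tau-active 0.555 --rho-update 0.422 --delta-new 1.517`.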