diff --git a/README.md b/README.md index ef533946..f9d87b91 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ | - 🤖 Custom models + 🤖 Add your model | @@ -64,17 +64,11 @@ 1) Create environment: ```shell -conda create -n diart python=3.8 +conda env create -f diart/environment.yml conda activate diart ``` -2) Install audio libraries: - -```shell -conda install portaudio pysoundfile ffmpeg -c conda-forge -``` - -3) Install diart: +2) Install the package: ```shell pip install diart ``` @@ -110,32 +104,32 @@ See `diart.stream -h` for more options. ### From python -Use `RealTimeInference` to easily run a pipeline on an audio source and write the results to disk: +Use `StreamingInference` to run a pipeline on an audio source and write the results to disk: ```python -from diart import OnlineSpeakerDiarization +from diart import SpeakerDiarization from diart.sources import MicrophoneAudioSource -from diart.inference import RealTimeInference +from diart.inference import StreamingInference from diart.sinks import RTTMWriter -pipeline = OnlineSpeakerDiarization() -mic = MicrophoneAudioSource(pipeline.config.sample_rate) -inference = RealTimeInference(pipeline, mic, do_plot=True) +pipeline = SpeakerDiarization() +mic = MicrophoneAudioSource() +inference = StreamingInference(pipeline, mic, do_plot=True) inference.attach_observers(RTTMWriter(mic.uri, "/output/file.rttm")) prediction = inference() ``` -For inference and evaluation on a dataset we recommend to use `Benchmark` (see notes on [reproducibility](#reproducibility)). +For inference and evaluation on a dataset we recommend to use `Benchmark` (see notes on [reproducibility](#-reproducibility)). -## 🤖 Custom models +## 🤖 Add your model -Third-party models can be integrated seamlessly by subclassing `SegmentationModel` and `EmbeddingModel` (which are PyTorch `Module` subclasses): +Third-party models can be integrated by subclassing `SegmentationModel` and `EmbeddingModel` (both PyTorch `nn.Module`): ```python -from diart import OnlineSpeakerDiarization, PipelineConfig +from diart import SpeakerDiarization, SpeakerDiarizationConfig from diart.models import EmbeddingModel, SegmentationModel from diart.sources import MicrophoneAudioSource -from diart.inference import RealTimeInference +from diart.inference import StreamingInference def model_loader(): @@ -168,19 +162,19 @@ class MyEmbeddingModel(EmbeddingModel): return self.model(waveform, weights) -config = PipelineConfig( +config = SpeakerDiarizationConfig( segmentation=MySegmentationModel(), embedding=MyEmbeddingModel() ) -pipeline = OnlineSpeakerDiarization(config) -mic = MicrophoneAudioSource(config.sample_rate) -inference = RealTimeInference(pipeline, mic) +pipeline = SpeakerDiarization(config) +mic = MicrophoneAudioSource() +inference = StreamingInference(pipeline, mic) prediction = inference() ``` ## 📈 Tune hyper-parameters -Diart implements a hyper-parameter optimizer based on [optuna](https://optuna.readthedocs.io/en/stable/index.html) that allows you to tune any pipeline to any dataset. +Diart implements an optimizer based on [optuna](https://optuna.readthedocs.io/en/stable/index.html) that allows you to tune pipeline hyper-parameters to your needs. 
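Each pipeline now declares its tunable hyper-parameters through the `HyperParameter` helpers defined in the new `src/diart/blocks/base.py` below. A minimal sketch of how the exposed parameters and their search ranges can be inspected, and how tuned values are passed back to a pipeline; the numeric values are the DIHARD III ones used in the benchmark example further down:

```python
from diart import SpeakerDiarization, SpeakerDiarizationConfig

# Each pipeline class lists the hyper-parameters it exposes for tuning,
# together with the range the optimizer searches (HyperParameter.low/high).
for hp in SpeakerDiarization.hyper_parameters():
    print(hp.name, hp.low, hp.high)  # tau_active, rho_update, delta_new

# Tuned values are fed back through the pipeline configuration
config = SpeakerDiarizationConfig(tau_active=0.555, rho_update=0.422, delta_new=1.517)
pipeline = SpeakerDiarization(config)
```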
### From the command line @@ -247,12 +241,11 @@ from diart.blocks import SpeakerSegmentation, OverlapAwareSpeakerEmbedding segmentation = SpeakerSegmentation.from_pyannote("pyannote/segmentation") embedding = OverlapAwareSpeakerEmbedding.from_pyannote("pyannote/embedding") -sample_rate = segmentation.model.sample_rate -mic = MicrophoneAudioSource(sample_rate) +mic = MicrophoneAudioSource() stream = mic.stream.pipe( # Reformat stream to 5s duration and 500ms shift - dops.rearrange_audio_stream(sample_rate=sample_rate), + dops.rearrange_audio_stream(sample_rate=segmentation.model.sample_rate), ops.map(lambda wav: (wav, segmentation(wav))), ops.starmap(embedding) ).subscribe(on_next=lambda emb: print(emb.shape)) @@ -281,7 +274,7 @@ diart.serve --host 0.0.0.0 --port 7007 diart.client microphone --host --port 7007 ``` -**Note:** please make sure that the client uses the same `step` and `sample_rate` than the server with `--step` and `-sr`. +**Note:** make sure that the client uses the same `step` and `sample_rate` as the server with `--step` and `-sr`. See `-h` for more options. @@ -290,13 +283,13 @@ See `-h` for more options. For customized solutions, a server can also be created in python using the `WebSocketAudioSource`: ```python -from diart import OnlineSpeakerDiarization +from diart import SpeakerDiarization from diart.sources import WebSocketAudioSource -from diart.inference import RealTimeInference +from diart.inference import StreamingInference -pipeline = OnlineSpeakerDiarization() +pipeline = SpeakerDiarization() source = WebSocketAudioSource(pipeline.config.sample_rate, "localhost", 7007) -inference = RealTimeInference(pipeline, source) +inference = StreamingInference(pipeline, source) inference.attach_hooks(lambda ann_wav: source.send(ann_wav[0].to_rttm())) prediction = inference() ``` @@ -347,21 +340,21 @@ To obtain the best results, make sure to use the following hyper-parameters: `diart.benchmark` and `diart.inference.Benchmark` can run, evaluate and measure the real-time latency of the pipeline. For instance, for a DIHARD III configuration: ```shell -diart.benchmark /wav/dir --reference /rttm/dir --tau=0.555 --rho=0.422 --delta=1.517 --segmentation pyannote/segmentation@Interspeech2021 +diart.benchmark /wav/dir --reference /rttm/dir --tau-active=0.555 --rho-update=0.422 --delta-new=1.517 --segmentation pyannote/segmentation@Interspeech2021 ``` or using the inference API: ```python from diart.inference import Benchmark, Parallelize -from diart import OnlineSpeakerDiarization, PipelineConfig +from diart import SpeakerDiarization, SpeakerDiarizationConfig from diart.models import SegmentationModel benchmark = Benchmark("/wav/dir", "/rttm/dir") name = "pyannote/segmentation@Interspeech2021" segmentation = SegmentationModel.from_pyannote(name) -config = PipelineConfig( +config = SpeakerDiarizationConfig( # Set the model used in the paper segmentation=segmentation, step=0.5, @@ -370,12 +363,12 @@ config = PipelineConfig( rho_update=0.422, delta_new=1.517 ) -benchmark(OnlineSpeakerDiarization, config) +benchmark(SpeakerDiarization, config) # Run the same benchmark in parallel p_benchmark = Parallelize(benchmark, num_workers=4) if __name__ == "__main__": # Needed for multiprocessing - p_benchmark(OnlineSpeakerDiarization, config) + p_benchmark(SpeakerDiarization, config) ``` This pre-calculates model outputs in batches, so it runs a lot faster.
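This release also adds a streaming `VoiceActivityDetection` pipeline (new `src/diart/blocks/vad.py` below), exported from the top-level `diart` package. A minimal sketch of its usage, assuming it plugs into `StreamingInference` exactly like `SpeakerDiarization` in the examples above:

```python
from diart import VoiceActivityDetection, VoiceActivityDetectionConfig
from diart.inference import StreamingInference
from diart.sources import MicrophoneAudioSource

# tau_active is the only hyper-parameter this pipeline exposes
config = VoiceActivityDetectionConfig(tau_active=0.6)
pipeline = VoiceActivityDetection(config)
mic = MicrophoneAudioSource()
inference = StreamingInference(pipeline, mic)
prediction = inference()  # Annotation with a single "speech" label
```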
diff --git a/environment.yml b/environment.yml new file mode 100644 index 00000000..f62b4274 --- /dev/null +++ b/environment.yml @@ -0,0 +1,12 @@ +name: diart +channels: + - conda-forge + - defaults +dependencies: + - python=3.8 + - portaudio=19.6.* + - pysoundfile=0.12.* + - ffmpeg[version='<4.4'] + - pip + - pip: + - . \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 50662023..e0d93213 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,7 +8,7 @@ tqdm>=4.64.0 pandas>=1.4.2 torch>=1.12.1 torchvision>=0.14.0 -torchaudio>=0.12.1,<1.0 +torchaudio>=2.0.2 pyannote.audio>=2.1.1 pyannote.core>=4.5 pyannote.database>=4.1.1 diff --git a/setup.cfg b/setup.cfg index 594c876e..f38a612e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,12 +1,12 @@ [metadata] name=diart -version=0.7.0 +version=0.8.0 author=Juan Manuel Coria -description=Speaker diarization in real time +description=Streaming speaker diarization in real-time long_description=file: README.md long_description_content_type=text/markdown keywords=speaker diarization, streaming, online, real time, rxpy -url=https://github.com/juanmc2005/StreamingSpeakerDiarization +url=https://github.com/juanmc2005/diart license=MIT classifiers= Development Status :: 4 - Beta @@ -30,7 +30,7 @@ install_requires= pandas>=1.4.2 torch>=1.12.1 torchvision>=0.14.0 - torchaudio>=0.12.1,<1.0 + torchaudio>=2.0.2 pyannote.audio>=2.1.1 pyannote.core>=4.5 pyannote.database>=4.1.1 diff --git a/src/diart/__init__.py b/src/diart/__init__.py index c9692638..4bd51327 100644 --- a/src/diart/__init__.py +++ b/src/diart/__init__.py @@ -1,6 +1,8 @@ from .blocks import ( - OnlineSpeakerDiarization, - BasePipeline, + SpeakerDiarization, + Pipeline, + SpeakerDiarizationConfig, PipelineConfig, - BasePipelineConfig, + VoiceActivityDetection, + VoiceActivityDetectionConfig, ) diff --git a/src/diart/argdoc.py b/src/diart/argdoc.py index d16df41e..e89caa28 100644 --- a/src/diart/argdoc.py +++ b/src/diart/argdoc.py @@ -1,5 +1,6 @@ SEGMENTATION = "Segmentation model name from pyannote" EMBEDDING = "Embedding model name from pyannote" +DURATION = "Chunk duration (in seconds)" STEP = "Sliding window step (in seconds)" LATENCY = "System latency (in seconds). STEP <= LATENCY <= CHUNK_DURATION" TAU = "Probability threshold to consider a speaker as active. 
0 <= TAU <= 1" diff --git a/src/diart/blocks/__init__.py b/src/diart/blocks/__init__.py index 59a6ef36..15cf81d9 100644 --- a/src/diart/blocks/__init__.py +++ b/src/diart/blocks/__init__.py @@ -13,6 +13,7 @@ OverlapAwareSpeakerEmbedding, ) from .segmentation import SpeakerSegmentation -from .diarization import OnlineSpeakerDiarization, BasePipeline -from .config import BasePipelineConfig, PipelineConfig +from .diarization import SpeakerDiarization, SpeakerDiarizationConfig +from .base import PipelineConfig, Pipeline from .utils import Binarize, Resample, AdjustVolume +from .vad import VoiceActivityDetection, VoiceActivityDetectionConfig diff --git a/src/diart/blocks/aggregation.py b/src/diart/blocks/aggregation.py index b6352a28..aa5e6a1e 100644 --- a/src/diart/blocks/aggregation.py +++ b/src/diart/blocks/aggregation.py @@ -1,3 +1,4 @@ +from abc import ABC, abstractmethod from typing import Optional, List import numpy as np @@ -5,7 +6,7 @@ from typing_extensions import Literal -class AggregationStrategy: +class AggregationStrategy(ABC): """Abstract class representing a strategy to aggregate overlapping buffers Parameters @@ -17,14 +18,18 @@ class AggregationStrategy: """ def __init__(self, cropping_mode: Literal["strict", "loose", "center"] = "loose"): - assert cropping_mode in ["strict", "loose", "center"], f"Invalid cropping mode `{cropping_mode}`" + assert cropping_mode in [ + "strict", + "loose", + "center", + ], f"Invalid cropping mode `{cropping_mode}`" self.cropping_mode = cropping_mode @staticmethod def build( name: Literal["mean", "hamming", "first"], - cropping_mode: Literal["strict", "loose", "center"] = "loose" - ) -> 'AggregationStrategy': + cropping_mode: Literal["strict", "loose", "center"] = "loose", + ) -> "AggregationStrategy": """Build an AggregationStrategy instance based on its name""" assert name in ("mean", "hamming", "first") if name == "mean": @@ -34,7 +39,9 @@ def build( else: return FirstOnlyStrategy(cropping_mode) - def __call__(self, buffers: List[SlidingWindowFeature], focus: Segment) -> SlidingWindowFeature: + def __call__( + self, buffers: List[SlidingWindowFeature], focus: Segment + ) -> SlidingWindowFeature: """Aggregate chunks over a specific region. 
Parameters @@ -52,20 +59,23 @@ def __call__(self, buffers: List[SlidingWindowFeature], focus: Segment) -> Slidi aggregation = self.aggregate(buffers, focus) resolution = focus.duration / aggregation.shape[0] resolution = SlidingWindow( - start=focus.start, - duration=resolution, - step=resolution + start=focus.start, duration=resolution, step=resolution ) return SlidingWindowFeature(aggregation, resolution) - def aggregate(self, buffers: List[SlidingWindowFeature], focus: Segment) -> np.ndarray: - raise NotImplementedError + @abstractmethod + def aggregate( + self, buffers: List[SlidingWindowFeature], focus: Segment + ) -> np.ndarray: + pass class HammingWeightedAverageStrategy(AggregationStrategy): """Compute the average weighted by the corresponding Hamming-window aligned to each buffer""" - def aggregate(self, buffers: List[SlidingWindowFeature], focus: Segment) -> np.ndarray: + def aggregate( + self, buffers: List[SlidingWindowFeature], focus: Segment + ) -> np.ndarray: num_frames, num_speakers = buffers[0].data.shape hamming, intersection = [], [] for buffer in buffers: @@ -85,19 +95,25 @@ def aggregate(self, buffers: List[SlidingWindowFeature], focus: Segment) -> np.n class AverageStrategy(AggregationStrategy): """Compute a simple average over the focus region""" - def aggregate(self, buffers: List[SlidingWindowFeature], focus: Segment) -> np.ndarray: + def aggregate( + self, buffers: List[SlidingWindowFeature], focus: Segment + ) -> np.ndarray: # Stack all overlapping regions - intersection = np.stack([ - buffer.crop(focus, mode=self.cropping_mode, fixed=focus.duration) - for buffer in buffers - ]) + intersection = np.stack( + [ + buffer.crop(focus, mode=self.cropping_mode, fixed=focus.duration) + for buffer in buffers + ] + ) return np.mean(intersection, axis=0) class FirstOnlyStrategy(AggregationStrategy): """Instead of aggregating, keep the first focus region in the buffer list""" - def aggregate(self, buffers: List[SlidingWindowFeature], focus: Segment) -> np.ndarray: + def aggregate( + self, buffers: List[SlidingWindowFeature], focus: Segment + ) -> np.ndarray: return buffers[0].crop(focus, mode=self.cropping_mode, fixed=focus.duration) @@ -149,12 +165,16 @@ def __init__( step: float, latency: Optional[float] = None, strategy: Literal["mean", "hamming", "first"] = "hamming", - cropping_mode: Literal["strict", "loose", "center"] = "loose" + cropping_mode: Literal["strict", "loose", "center"] = "loose", ): self.step = step self.latency = latency self.strategy = strategy - assert cropping_mode in ["strict", "loose", "center"], f"Invalid cropping mode `{cropping_mode}`" + assert cropping_mode in [ + "strict", + "loose", + "center", + ], f"Invalid cropping mode `{cropping_mode}`" self.cropping_mode = cropping_mode if self.latency is None: @@ -169,7 +189,7 @@ def _prepend( self, output_window: SlidingWindowFeature, output_region: Segment, - buffers: List[SlidingWindowFeature] + buffers: List[SlidingWindowFeature], ): # FIXME instead of prepending the output of the first chunk, # add padding of `chunk_duration - latency` seconds at the @@ -187,7 +207,7 @@ def _prepend( resolution = output_region.end / first_output.shape[0] output_window = SlidingWindowFeature( first_output, - SlidingWindow(start=0, duration=resolution, step=resolution) + SlidingWindow(start=0, duration=resolution, step=resolution), ) return output_window diff --git a/src/diart/blocks/base.py b/src/diart/blocks/base.py new file mode 100644 index 00000000..f6ca3a33 --- /dev/null +++ b/src/diart/blocks/base.py @@ -0,0 
+1,95 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Tuple, Sequence, Text + +from pyannote.core import SlidingWindowFeature +from pyannote.metrics.base import BaseMetric + +from .. import utils +from ..audio import FilePath, AudioLoader + + +@dataclass +class HyperParameter: + name: Text + low: float + high: float + + @staticmethod + def from_name(name: Text) -> "HyperParameter": + if name == "tau_active": + return TauActive + if name == "rho_update": + return RhoUpdate + if name == "delta_new": + return DeltaNew + raise ValueError(f"Hyper-parameter '{name}' not recognized") + + +TauActive = HyperParameter("tau_active", low=0, high=1) +RhoUpdate = HyperParameter("rho_update", low=0, high=1) +DeltaNew = HyperParameter("delta_new", low=0, high=2) + + +class PipelineConfig(ABC): + @property + @abstractmethod + def duration(self) -> float: + pass + + @property + @abstractmethod + def step(self) -> float: + pass + + @property + @abstractmethod + def latency(self) -> float: + pass + + @property + @abstractmethod + def sample_rate(self) -> int: + pass + + def get_file_padding(self, filepath: FilePath) -> Tuple[float, float]: + file_duration = AudioLoader(self.sample_rate, mono=True).get_duration(filepath) + right = utils.get_padding_right(self.latency, self.step) + left = utils.get_padding_left(file_duration + right, self.duration) + return left, right + + +class Pipeline(ABC): + @staticmethod + @abstractmethod + def get_config_class() -> type: + pass + + @staticmethod + @abstractmethod + def suggest_metric() -> BaseMetric: + pass + + @staticmethod + @abstractmethod + def hyper_parameters() -> Sequence[HyperParameter]: + pass + + @property + @abstractmethod + def config(self) -> PipelineConfig: + pass + + @abstractmethod + def reset(self): + pass + + @abstractmethod + def set_timestamp_shift(self, shift: float): + pass + + @abstractmethod + def __call__( + self, waveforms: Sequence[SlidingWindowFeature] + ) -> Sequence[Tuple[Any, SlidingWindowFeature]]: + pass diff --git a/src/diart/blocks/clustering.py b/src/diart/blocks/clustering.py index 882001b9..b7217c0a 100644 --- a/src/diart/blocks/clustering.py +++ b/src/diart/blocks/clustering.py @@ -27,13 +27,14 @@ class OnlineSpeakerClustering: max_speakers: int Maximum number of global speakers to track through a conversation. Defaults to 20. """ + def __init__( self, tau_active: float, rho_update: float, delta_new: float, metric: Optional[str] = "cosine", - max_speakers: int = 20 + max_speakers: int = 20, ): self.tau_active = tau_active self.rho_update = rho_update @@ -116,9 +117,7 @@ def add_center(self, embedding: np.ndarray) -> int: return center def identify( - self, - segmentation: SlidingWindowFeature, - embeddings: torch.Tensor + self, segmentation: SlidingWindowFeature, embeddings: torch.Tensor ) -> SpeakerMap: """Identify the centroids to which the input speaker embeddings belong. @@ -135,15 +134,18 @@ def identify( A mapping from local speakers to global speakers. 
""" embeddings = embeddings.detach().cpu().numpy() - active_speakers = np.where(np.max(segmentation.data, axis=0) >= self.tau_active)[0] - long_speakers = np.where(np.mean(segmentation.data, axis=0) >= self.rho_update)[0] + active_speakers = np.where( + np.max(segmentation.data, axis=0) >= self.tau_active + )[0] + long_speakers = np.where(np.mean(segmentation.data, axis=0) >= self.rho_update)[ + 0 + ] num_local_speakers = segmentation.data.shape[1] if self.centers is None: self.init_centers(embeddings.shape[1]) assignments = [ - (spk, self.add_center(embeddings[spk])) - for spk in active_speakers + (spk, self.add_center(embeddings[spk])) for spk in active_speakers ] return SpeakerMapBuilder.hard_map( shape=(num_local_speakers, self.max_speakers), @@ -154,18 +156,16 @@ def identify( # Obtain a mapping based on distances between embeddings and centers dist_map = SpeakerMapBuilder.dist(embeddings, self.centers, self.metric) # Remove any assignments containing invalid speakers - inactive_speakers = np.array([ - spk for spk in range(num_local_speakers) - if spk not in active_speakers - ]) + inactive_speakers = np.array( + [spk for spk in range(num_local_speakers) if spk not in active_speakers] + ) dist_map = dist_map.unmap_speakers(inactive_speakers, self.inactive_centers) # Keep assignments under the distance threshold valid_map = dist_map.unmap_threshold(self.delta_new) # Some speakers might be unidentified missed_speakers = [ - s for s in active_speakers - if not valid_map.is_source_speaker_mapped(s) + s for s in active_speakers if not valid_map.is_source_speaker_mapped(s) ] # Add assignments to new centers if possible @@ -205,8 +205,10 @@ def identify( return valid_map - def __call__(self, segmentation: SlidingWindowFeature, embeddings: torch.Tensor) -> SlidingWindowFeature: + def __call__( + self, segmentation: SlidingWindowFeature, embeddings: torch.Tensor + ) -> SlidingWindowFeature: return SlidingWindowFeature( self.identify(segmentation, embeddings).apply(segmentation.data), - segmentation.sliding_window + segmentation.sliding_window, ) diff --git a/src/diart/blocks/config.py b/src/diart/blocks/config.py deleted file mode 100644 index d8e2a656..00000000 --- a/src/diart/blocks/config.py +++ /dev/null @@ -1,153 +0,0 @@ -from typing import Any, Optional, Union, Tuple - -import numpy as np -import torch -from typing_extensions import Literal - -from .. import models as m -from .. 
import utils -from ..audio import FilePath, AudioLoader - - -class BasePipelineConfig: - @property - def duration(self) -> float: - raise NotImplementedError - - @property - def step(self) -> float: - raise NotImplementedError - - @property - def latency(self) -> float: - raise NotImplementedError - - @property - def sample_rate(self) -> int: - raise NotImplementedError - - @staticmethod - def from_dict(data: Any) -> 'BasePipelineConfig': - raise NotImplementedError - - def get_file_padding(self, filepath: FilePath) -> Tuple[float, float]: - file_duration = AudioLoader(self.sample_rate, mono=True).get_duration(filepath) - right = utils.get_padding_right(self.latency, self.step) - left = utils.get_padding_left(file_duration + right, self.duration) - return left, right - - def optimal_block_size(self) -> int: - return int(np.rint(self.step * self.sample_rate)) - - -class PipelineConfig(BasePipelineConfig): - def __init__( - self, - segmentation: Optional[m.SegmentationModel] = None, - embedding: Optional[m.EmbeddingModel] = None, - duration: Optional[float] = None, - step: float = 0.5, - latency: Optional[Union[float, Literal["max", "min"]]] = None, - tau_active: float = 0.6, - rho_update: float = 0.3, - delta_new: float = 1, - gamma: float = 3, - beta: float = 10, - max_speakers: int = 20, - device: Optional[torch.device] = None, - **kwargs, - ): - # Default segmentation model is pyannote/segmentation - self.segmentation = segmentation - if self.segmentation is None: - self.segmentation = m.SegmentationModel.from_pyannote("pyannote/segmentation") - - # Default duration is the one given by the segmentation model - self._duration = duration - - # Expected sample rate is given by the segmentation model - self._sample_rate: Optional[int] = None - - # Default embedding model is pyannote/embedding - self.embedding = embedding - if self.embedding is None: - self.embedding = m.EmbeddingModel.from_pyannote("pyannote/embedding") - - # Latency defaults to the step duration - self._step = step - self._latency = latency - if self._latency is None or self._latency == "min": - self._latency = self._step - elif self._latency == "max": - self._latency = self._duration - - self.tau_active = tau_active - self.rho_update = rho_update - self.delta_new = delta_new - self.gamma = gamma - self.beta = beta - self.max_speakers = max_speakers - - self.device = device - if self.device is None: - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - @staticmethod - def from_dict(data: Any) -> 'PipelineConfig': - # Check for explicit device, otherwise check for 'cpu' bool, otherwise pass None - device = utils.get(data, "device", None) - if device is None: - device = torch.device("cpu") if utils.get(data, "cpu", False) else None - - # Instantiate models - hf_token = utils.parse_hf_token_arg(utils.get(data, "hf_token", True)) - segmentation = utils.get(data, "segmentation", "pyannote/segmentation") - segmentation = m.SegmentationModel.from_pyannote(segmentation, hf_token) - embedding = utils.get(data, "embedding", "pyannote/embedding") - embedding = m.EmbeddingModel.from_pyannote(embedding, hf_token) - - # Hyper-parameters and their aliases - tau = utils.get(data, "tau_active", None) - if tau is None: - tau = utils.get(data, "tau", 0.6) - rho = utils.get(data, "rho_update", None) - if rho is None: - rho = utils.get(data, "rho", 0.3) - delta = utils.get(data, "delta_new", None) - if delta is None: - delta = utils.get(data, "delta", 1) - - return PipelineConfig( - segmentation=segmentation, - 
embedding=embedding, - duration=utils.get(data, "duration", None), - step=utils.get(data, "step", 0.5), - latency=utils.get(data, "latency", None), - tau_active=tau, - rho_update=rho, - delta_new=delta, - gamma=utils.get(data, "gamma", 3), - beta=utils.get(data, "beta", 10), - max_speakers=utils.get(data, "max_speakers", 20), - device=device, - ) - - @property - def duration(self) -> float: - if self._duration is None: - self._duration = self.segmentation.duration - return self._duration - - @property - def step(self) -> float: - return self._step - - @property - def latency(self) -> float: - return self._latency - - @property - def sample_rate(self) -> int: - if self._sample_rate is None: - self._sample_rate = self.segmentation.sample_rate - return self._sample_rate diff --git a/src/diart/blocks/diarization.py b/src/diart/blocks/diarization.py index 7f0e162c..fab83c36 100644 --- a/src/diart/blocks/diarization.py +++ b/src/diart/blocks/diarization.py @@ -1,49 +1,111 @@ -from typing import Optional, Tuple, Sequence +from __future__ import annotations + +from typing import Sequence import numpy as np import torch from pyannote.core import Annotation, SlidingWindowFeature, SlidingWindow, Segment +from pyannote.metrics.base import BaseMetric +from pyannote.metrics.diarization import DiarizationErrorRate +from typing_extensions import Literal +from . import base from .aggregation import DelayedAggregation from .clustering import OnlineSpeakerClustering from .embedding import OverlapAwareSpeakerEmbedding from .segmentation import SpeakerSegmentation from .utils import Binarize -from .config import BasePipelineConfig, PipelineConfig +from .. import models as m -class BasePipeline: - @staticmethod - def get_config_class() -> type: - raise NotImplementedError +class SpeakerDiarizationConfig(base.PipelineConfig): + def __init__( + self, + segmentation: m.SegmentationModel | None = None, + embedding: m.EmbeddingModel | None = None, + duration: float | None = None, + step: float = 0.5, + latency: float | Literal["max", "min"] | None = None, + tau_active: float = 0.6, + rho_update: float = 0.3, + delta_new: float = 1, + gamma: float = 3, + beta: float = 10, + max_speakers: int = 20, + device: torch.device | None = None, + **kwargs, + ): + # Default segmentation model is pyannote/segmentation + self.segmentation = segmentation or m.SegmentationModel.from_pyannote( + "pyannote/segmentation" + ) + + # Default embedding model is pyannote/embedding + self.embedding = embedding or m.EmbeddingModel.from_pyannote( + "pyannote/embedding" + ) + + self._duration = duration + self._sample_rate: int | None = None + + # Latency defaults to the step duration + self._step = step + self._latency = latency + if self._latency is None or self._latency == "min": + self._latency = self._step + elif self._latency == "max": + self._latency = self._duration + + self.tau_active = tau_active + self.rho_update = rho_update + self.delta_new = delta_new + self.gamma = gamma + self.beta = beta + self.max_speakers = max_speakers + + self.device = device or torch.device( + "cuda" if torch.cuda.is_available() else "cpu" + ) @property - def config(self) -> BasePipelineConfig: - raise NotImplementedError + def duration(self) -> float: + # Default duration is the one given by the segmentation model + if self._duration is None: + self._duration = self.segmentation.duration + return self._duration - def reset(self): - raise NotImplementedError + @property + def step(self) -> float: + return self._step - def set_timestamp_shift(self, shift: 
float): - raise NotImplementedError + @property + def latency(self) -> float: + return self._latency - def __call__( - self, - waveforms: Sequence[SlidingWindowFeature] - ) -> Sequence[Tuple[Annotation, SlidingWindowFeature]]: - raise NotImplementedError + @property + def sample_rate(self) -> int: + # Expected sample rate is given by the segmentation model + if self._sample_rate is None: + self._sample_rate = self.segmentation.sample_rate + return self._sample_rate -class OnlineSpeakerDiarization(BasePipeline): - def __init__(self, config: Optional[PipelineConfig] = None): - self._config = PipelineConfig() if config is None else config +class SpeakerDiarization(base.Pipeline): + def __init__(self, config: SpeakerDiarizationConfig | None = None): + self._config = SpeakerDiarizationConfig() if config is None else config msg = f"Latency should be in the range [{self._config.step}, {self._config.duration}]" assert self._config.step <= self._config.latency <= self._config.duration, msg - self.segmentation = SpeakerSegmentation(self._config.segmentation, self._config.device) + self.segmentation = SpeakerSegmentation( + self._config.segmentation, self._config.device + ) self.embedding = OverlapAwareSpeakerEmbedding( - self._config.embedding, self._config.gamma, self._config.beta, norm=1, device=self._config.device + self._config.embedding, + self._config.gamma, + self._config.beta, + norm=1, + device=self._config.device, ) self.pred_aggregation = DelayedAggregation( self._config.step, @@ -67,10 +129,18 @@ def __init__(self, config: Optional[PipelineConfig] = None): @staticmethod def get_config_class() -> type: - return PipelineConfig + return SpeakerDiarizationConfig + + @staticmethod + def suggest_metric() -> BaseMetric: + return DiarizationErrorRate(collar=0, skip_overlap=False) + + @staticmethod + def hyper_parameters() -> Sequence[base.HyperParameter]: + return [base.TauActive, base.RhoUpdate, base.DeltaNew] @property - def config(self) -> PipelineConfig: + def config(self) -> SpeakerDiarizationConfig: return self._config def set_timestamp_shift(self, shift: float): @@ -88,9 +158,8 @@ def reset(self): self.chunk_buffer, self.pred_buffer = [], [] def __call__( - self, - waveforms: Sequence[SlidingWindowFeature] - ) -> Sequence[Tuple[Annotation, SlidingWindowFeature]]: + self, waveforms: Sequence[SlidingWindowFeature] + ) -> Sequence[tuple[Annotation, SlidingWindowFeature]]: batch_size = len(waveforms) msg = "Pipeline expected at least 1 input" assert batch_size >= 1, msg @@ -98,13 +167,17 @@ def __call__( # Create batch from chunk sequence, shape (batch, samples, channels) batch = torch.stack([torch.from_numpy(w.data) for w in waveforms]) - expected_num_samples = int(np.rint(self.config.duration * self.config.sample_rate)) + expected_num_samples = int( + np.rint(self.config.duration * self.config.sample_rate) + ) msg = f"Expected {expected_num_samples} samples per chunk, but got {batch.shape[1]}" assert batch.shape[1] == expected_num_samples, msg # Extract segmentation and embeddings segmentations = self.segmentation(batch) # shape (batch, frames, speakers) - embeddings = self.embedding(batch, segmentations) # shape (batch, speakers, emb_dim) + embeddings = self.embedding( + batch, segmentations + ) # shape (batch, speakers, emb_dim) seg_resolution = waveforms[0].extent.duration / segmentations.shape[1] @@ -133,7 +206,9 @@ def __call__( # Shift prediction timestamps if required if self.timestamp_shift != 0: shifted_agg_prediction = Annotation(agg_prediction.uri) - for segment, track, speaker 
in agg_prediction.itertracks(yield_label=True): + for segment, track, speaker in agg_prediction.itertracks( + yield_label=True + ): new_segment = Segment( segment.start + self.timestamp_shift, segment.end + self.timestamp_shift, diff --git a/src/diart/blocks/embedding.py b/src/diart/blocks/embedding.py index 7aa31c05..5cd7c39e 100644 --- a/src/diart/blocks/embedding.py +++ b/src/diart/blocks/embedding.py @@ -22,12 +22,14 @@ def __init__(self, model: EmbeddingModel, device: Optional[torch.device] = None) def from_pyannote( model, use_hf_token: Union[Text, bool, None] = True, - device: Optional[torch.device] = None - ) -> 'SpeakerEmbedding': + device: Optional[torch.device] = None, + ) -> "SpeakerEmbedding": emb_model = EmbeddingModel.from_pyannote(model, use_hf_token) return SpeakerEmbedding(emb_model, device) - def __call__(self, waveform: TemporalFeatures, weights: Optional[TemporalFeatures] = None) -> torch.Tensor: + def __call__( + self, waveform: TemporalFeatures, weights: Optional[TemporalFeatures] = None + ) -> torch.Tensor: """ Calculate speaker embeddings of input audio. If weights are given, calculate many speaker embeddings from the same waveform. @@ -58,7 +60,7 @@ def __call__(self, waveform: TemporalFeatures, weights: Optional[TemporalFeature self.model(inputs, weights), "(batch spk) feat -> batch spk feat", batch=batch_size, - spk=num_speakers + spk=num_speakers, ) else: output = self.model(inputs) @@ -76,6 +78,7 @@ class OverlappedSpeechPenalty: Temperature parameter (actually 1/beta) to lower joint speaker activations. Defaults to 10. """ + def __init__(self, gamma: float = 3, beta: float = 10): self.gamma = gamma self.beta = beta @@ -106,7 +109,11 @@ def __call__(self, embeddings: torch.Tensor) -> torch.Tensor: batch_size2, num_speakers2, _ = embeddings.shape assert batch_size1 == batch_size2 and num_speakers1 == num_speakers2 with torch.no_grad(): - norm_embs = self.norm * embeddings / torch.norm(embeddings, p=2, dim=-1, keepdim=True) + norm_embs = ( + self.norm + * embeddings + / torch.norm(embeddings, p=2, dim=-1, keepdim=True) + ) return norm_embs @@ -131,6 +138,7 @@ class OverlapAwareSpeakerEmbedding: The device on which to run the embedding model. Defaults to GPU if available or CPU if not. """ + def __init__( self, model: EmbeddingModel, @@ -155,5 +163,7 @@ def from_pyannote( model = EmbeddingModel.from_pyannote(model, use_hf_token) return OverlapAwareSpeakerEmbedding(model, gamma, beta, norm, device) - def __call__(self, waveform: TemporalFeatures, segmentation: TemporalFeatures) -> torch.Tensor: + def __call__( + self, waveform: TemporalFeatures, segmentation: TemporalFeatures + ) -> torch.Tensor: return self.normalize(self.embedding(waveform, self.osp(segmentation))) diff --git a/src/diart/blocks/segmentation.py b/src/diart/blocks/segmentation.py index 8fda3ffc..e946c748 100644 --- a/src/diart/blocks/segmentation.py +++ b/src/diart/blocks/segmentation.py @@ -21,8 +21,8 @@ def __init__(self, model: SegmentationModel, device: Optional[torch.device] = No def from_pyannote( model, use_hf_token: Union[Text, bool, None] = True, - device: Optional[torch.device] = None - ) -> 'SpeakerSegmentation': + device: Optional[torch.device] = None, + ) -> "SpeakerSegmentation": seg_model = SegmentationModel.from_pyannote(model, use_hf_token) return SpeakerSegmentation(seg_model, device) @@ -40,6 +40,9 @@ def __call__(self, waveform: TemporalFeatures) -> TemporalFeatures: The batch dimension is omitted if waveform is a `SlidingWindowFeature`. 
""" with torch.no_grad(): - wave = rearrange(self.formatter.cast(waveform), "batch sample channel -> batch channel sample") + wave = rearrange( + self.formatter.cast(waveform), + "batch sample channel -> batch channel sample", + ) output = self.model(wave.to(self.device)).cpu() return self.formatter.restore_type(output) diff --git a/src/diart/blocks/utils.py b/src/diart/blocks/utils.py index 02594e3d..9c0afc3e 100644 --- a/src/diart/blocks/utils.py +++ b/src/diart/blocks/utils.py @@ -69,12 +69,21 @@ class Resample: resample_rate: int Sample rate of the output """ - def __init__(self, sample_rate: int, resample_rate: int): - self.resample = T.Resample(sample_rate, resample_rate) + + def __init__( + self, + sample_rate: int, + resample_rate: int, + device: Optional[torch.device] = None, + ): + self.device = device + if self.device is None: + self.device = torch.device("cpu") + self.resample = T.Resample(sample_rate, resample_rate).to(self.device) self.formatter = TemporalFeatureFormatter() def __call__(self, waveform: TemporalFeatures) -> TemporalFeatures: - wav = self.formatter.cast(waveform) # shape (batch, samples, 1) + wav = self.formatter.cast(waveform).to(self.device) # shape (batch, samples, 1) with torch.no_grad(): resampled_wav = self.resample(wav.transpose(-1, -2)).transpose(-1, -2) return self.formatter.restore_type(resampled_wav) @@ -90,6 +99,7 @@ class AdjustVolume: volume_in_db: float Target volume in dB. """ + def __init__(self, volume_in_db: float): self.target_db = volume_in_db self.formatter = TemporalFeatureFormatter() @@ -108,7 +118,9 @@ def get_volumes(waveforms: torch.Tensor) -> torch.Tensor: volumes: torch.Tensor Audio chunk volumes per channel. Shape (batch, 1, channels) """ - return 10 * torch.log10(torch.mean(torch.abs(waveforms) ** 2, dim=1, keepdim=True)) + return 10 * torch.log10( + torch.mean(torch.abs(waveforms) ** 2, dim=1, keepdim=True) + ) def __call__(self, waveform: TemporalFeatures) -> TemporalFeatures: wav = self.formatter.cast(waveform) # shape (batch, samples, channels) diff --git a/src/diart/blocks/vad.py b/src/diart/blocks/vad.py new file mode 100644 index 00000000..0edd3e0b --- /dev/null +++ b/src/diart/blocks/vad.py @@ -0,0 +1,196 @@ +from __future__ import annotations + +from typing import Sequence + +import numpy as np +import torch +from pyannote.core import ( + Annotation, + Timeline, + SlidingWindowFeature, + SlidingWindow, + Segment, +) +from pyannote.metrics.base import BaseMetric +from pyannote.metrics.detection import DetectionErrorRate +from typing_extensions import Literal + +from . import base +from .aggregation import DelayedAggregation +from .segmentation import SpeakerSegmentation +from .utils import Binarize +from .. import models as m +from .. 
import utils + + +class VoiceActivityDetectionConfig(base.PipelineConfig): + def __init__( + self, + segmentation: m.SegmentationModel | None = None, + duration: float | None = None, + step: float = 0.5, + latency: float | Literal["max", "min"] | None = None, + tau_active: float = 0.6, + device: torch.device | None = None, + **kwargs, + ): + # Default segmentation model is pyannote/segmentation + self.segmentation = segmentation or m.SegmentationModel.from_pyannote( + "pyannote/segmentation" + ) + + self._duration = duration + self._step = step + self._sample_rate: int | None = None + + # Latency defaults to the step duration + self._latency = latency + if self._latency is None or self._latency == "min": + self._latency = self._step + elif self._latency == "max": + self._latency = self._duration + + self.tau_active = tau_active + self.device = device or torch.device( + "cuda" if torch.cuda.is_available() else "cpu" + ) + + @property + def duration(self) -> float: + # Default duration is the one given by the segmentation model + if self._duration is None: + self._duration = self.segmentation.duration + return self._duration + + @property + def step(self) -> float: + return self._step + + @property + def latency(self) -> float: + return self._latency + + @property + def sample_rate(self) -> int: + # Expected sample rate is given by the segmentation model + if self._sample_rate is None: + self._sample_rate = self.segmentation.sample_rate + return self._sample_rate + + +class VoiceActivityDetection(base.Pipeline): + def __init__(self, config: VoiceActivityDetectionConfig | None = None): + self._config = VoiceActivityDetectionConfig() if config is None else config + + msg = f"Latency should be in the range [{self._config.step}, {self._config.duration}]" + assert self._config.step <= self._config.latency <= self._config.duration, msg + + self.segmentation = SpeakerSegmentation( + self._config.segmentation, self._config.device + ) + self.pred_aggregation = DelayedAggregation( + self._config.step, + self._config.latency, + strategy="hamming", + cropping_mode="loose", + ) + self.audio_aggregation = DelayedAggregation( + self._config.step, + self._config.latency, + strategy="first", + cropping_mode="center", + ) + self.binarize = Binarize(self._config.tau_active) + + # Internal state, handle with care + self.timestamp_shift = 0 + self.chunk_buffer, self.pred_buffer = [], [] + + @staticmethod + def get_config_class() -> type: + return VoiceActivityDetectionConfig + + @staticmethod + def suggest_metric() -> BaseMetric: + return DetectionErrorRate(collar=0, skip_overlap=False) + + @staticmethod + def hyper_parameters() -> Sequence[base.HyperParameter]: + return [base.TauActive] + + @property + def config(self) -> base.PipelineConfig: + return self._config + + def reset(self): + self.set_timestamp_shift(0) + self.chunk_buffer, self.pred_buffer = [], [] + + def set_timestamp_shift(self, shift: float): + self.timestamp_shift = shift + + def __call__( + self, + waveforms: Sequence[SlidingWindowFeature], + ) -> Sequence[tuple[Annotation, SlidingWindowFeature]]: + batch_size = len(waveforms) + msg = "Pipeline expected at least 1 input" + assert batch_size >= 1, msg + + # Create batch from chunk sequence, shape (batch, samples, channels) + batch = torch.stack([torch.from_numpy(w.data) for w in waveforms]) + + expected_num_samples = int( + np.rint(self.config.duration * self.config.sample_rate) + ) + msg = f"Expected {expected_num_samples} samples per chunk, but got {batch.shape[1]}" + assert batch.shape[1] == 
expected_num_samples, msg + + # Extract segmentation + segmentations = self.segmentation(batch) # shape (batch, frames, speakers) + voice_detection = torch.max(segmentations, dim=-1, keepdim=True)[ + 0 + ] # shape (batch, frames, 1) + + seg_resolution = waveforms[0].extent.duration / segmentations.shape[1] + + outputs = [] + for wav, vad in zip(waveforms, voice_detection): + # Add timestamps to segmentation + sw = SlidingWindow( + start=wav.extent.start, + duration=seg_resolution, + step=seg_resolution, + ) + vad = SlidingWindowFeature(vad.cpu().numpy(), sw) + + # Update sliding buffer + self.chunk_buffer.append(wav) + self.pred_buffer.append(vad) + + # Aggregate buffer outputs for this time step + agg_waveform = self.audio_aggregation(self.chunk_buffer) + agg_prediction = self.pred_aggregation(self.pred_buffer) + agg_prediction = self.binarize(agg_prediction).get_timeline(copy=False) + + # Shift prediction timestamps if required + if self.timestamp_shift != 0: + shifted_agg_prediction = Timeline(uri=agg_prediction.uri) + for segment in agg_prediction: + new_segment = Segment( + segment.start + self.timestamp_shift, + segment.end + self.timestamp_shift, + ) + shifted_agg_prediction.add(new_segment) + agg_prediction = shifted_agg_prediction + + # Convert timeline into annotation with single speaker "speech" + agg_prediction = agg_prediction.to_annotation(utils.repeat_label("speech")) + outputs.append((agg_prediction, agg_waveform)) + + # Make place for new chunks in buffer if required + if len(self.chunk_buffer) == self.pred_aggregation.num_overlapping_windows: + self.chunk_buffer = self.chunk_buffer[1:] + self.pred_buffer = self.pred_buffer[1:] + + return outputs diff --git a/src/diart/console/benchmark.py b/src/diart/console/benchmark.py index b6a3f9ff..b5a296d1 100644 --- a/src/diart/console/benchmark.py +++ b/src/diart/console/benchmark.py @@ -1,39 +1,116 @@ import argparse from pathlib import Path -import diart.argdoc as argdoc import pandas as pd -from diart.blocks import OnlineSpeakerDiarization, PipelineConfig +import torch + +from diart import argdoc +from diart import models as m +from diart import utils from diart.inference import Benchmark, Parallelize def run(): parser = argparse.ArgumentParser() - parser.add_argument("root", type=Path, help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)") - parser.add_argument("--segmentation", default="pyannote/segmentation", type=str, - help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation") - parser.add_argument("--embedding", default="pyannote/embedding", type=str, - help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding") - parser.add_argument("--reference", type=Path, - help="Optional. Directory with RTTM files CONVERSATION.rttm. Names must match audio files") - parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5") - parser.add_argument("--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5") - parser.add_argument("--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5") - parser.add_argument("--rho", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3") - parser.add_argument("--delta", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1") - parser.add_argument("--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. Defaults to 3") - parser.add_argument("--beta", default=10, type=float, help=f"{argdoc.BETA}. Defaults to 10") - parser.add_argument("--max-speakers", default=20, type=int, help=f"{argdoc.MAX_SPEAKERS}. 
Defaults to 20") - parser.add_argument("--batch-size", default=32, type=int, help=f"{argdoc.BATCH_SIZE}. Defaults to 32") - parser.add_argument("--num-workers", default=0, type=int, - help=f"{argdoc.NUM_WORKERS}. Defaults to 0 (no parallelism)") - parser.add_argument("--cpu", dest="cpu", action="store_true", - help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise") - parser.add_argument("--output", type=Path, help=f"{argdoc.OUTPUT}. Defaults to no writing") - parser.add_argument("--hf-token", default="true", type=str, - help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)") + parser.add_argument( + "root", + type=Path, + help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)", + ) + parser.add_argument( + "--pipeline", + default="SpeakerDiarization", + type=str, + help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'", + ) + parser.add_argument( + "--segmentation", + default="pyannote/segmentation", + type=str, + help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation", + ) + parser.add_argument( + "--embedding", + default="pyannote/embedding", + type=str, + help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding", + ) + parser.add_argument( + "--reference", + type=Path, + help="Optional. Directory with RTTM files CONVERSATION.rttm. Names must match audio files", + ) + parser.add_argument( + "--duration", + type=float, + help=f"{argdoc.DURATION}. Defaults to training segmentation duration", + ) + parser.add_argument( + "--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5" + ) + parser.add_argument( + "--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5" + ) + parser.add_argument( + "--tau-active", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5" + ) + parser.add_argument( + "--rho-update", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3" + ) + parser.add_argument( + "--delta-new", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1" + ) + parser.add_argument( + "--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. Defaults to 3" + ) + parser.add_argument( + "--beta", default=10, type=float, help=f"{argdoc.BETA}. Defaults to 10" + ) + parser.add_argument( + "--max-speakers", + default=20, + type=int, + help=f"{argdoc.MAX_SPEAKERS}. Defaults to 20", + ) + parser.add_argument( + "--batch-size", + default=32, + type=int, + help=f"{argdoc.BATCH_SIZE}. Defaults to 32", + ) + parser.add_argument( + "--num-workers", + default=0, + type=int, + help=f"{argdoc.NUM_WORKERS}. Defaults to 0 (no parallelism)", + ) + parser.add_argument( + "--cpu", + dest="cpu", + action="store_true", + help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise", + ) + parser.add_argument( + "--output", type=Path, help=f"{argdoc.OUTPUT}. Defaults to no writing" + ) + parser.add_argument( + "--hf-token", + default="true", + type=str, + help=f"{argdoc.HF_TOKEN}. 
Defaults to 'true' (required by pyannote)", + ) args = parser.parse_args() + # Resolve device + args.device = torch.device("cpu") if args.cpu else None + + # Resolve models + hf_token = utils.parse_hf_token_arg(args.hf_token) + args.segmentation = m.SegmentationModel.from_pyannote(args.segmentation, hf_token) + args.embedding = m.EmbeddingModel.from_pyannote(args.embedding, hf_token) + + pipeline_class = utils.get_pipeline_class(args.pipeline) + benchmark = Benchmark( args.root, args.reference, @@ -43,11 +120,11 @@ def run(): batch_size=args.batch_size, ) - config = PipelineConfig.from_dict(vars(args)) + config = pipeline_class.get_config_class()(**vars(args)) if args.num_workers > 0: benchmark = Parallelize(benchmark, args.num_workers) - report = benchmark(OnlineSpeakerDiarization, config) + report = benchmark(pipeline_class, config) if args.output is not None and isinstance(report, pd.DataFrame): report.to_csv(args.output / "benchmark_report.csv") diff --git a/src/diart/console/client.py b/src/diart/console/client.py index 084dbc13..b3de36db 100644 --- a/src/diart/console/client.py +++ b/src/diart/console/client.py @@ -3,28 +3,25 @@ from threading import Thread from typing import Text, Optional -import diart.argdoc as argdoc -import diart.sources as src -import diart.utils as utils -import numpy as np import rx.operators as ops from websocket import WebSocket +from diart import argdoc +from diart import sources as src +from diart import utils + def send_audio(ws: WebSocket, source: Text, step: float, sample_rate: int): # Create audio source - block_size = int(np.rint(step * sample_rate)) source_components = source.split(":") if source_components[0] != "microphone": - audio_source = src.FileAudioSource(source, sample_rate) + audio_source = src.FileAudioSource(source, sample_rate, block_duration=step) else: device = int(source_components[1]) if len(source_components) > 1 else None - audio_source = src.MicrophoneAudioSource(sample_rate, block_size, device) + audio_source = src.MicrophoneAudioSource(step, device) # Encode audio, then send through websocket - audio_source.stream.pipe( - ops.map(utils.encode_audio) - ).subscribe_(ws.send) + audio_source.stream.pipe(ops.map(utils.encode_audio)).subscribe_(ws.send) # Start reading audio audio_source.read() @@ -41,18 +38,37 @@ def receive_audio(ws: WebSocket, output: Optional[Path]): def run(): parser = argparse.ArgumentParser() - parser.add_argument("source", type=str, help="Path to an audio file | 'microphone' | 'microphone:'") + parser.add_argument( + "source", + type=str, + help="Path to an audio file | 'microphone' | 'microphone:'", + ) parser.add_argument("--host", required=True, type=str, help="Server host") parser.add_argument("--port", required=True, type=int, help="Server port") - parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5") - parser.add_argument("-sr", "--sample-rate", default=16000, type=int, help=f"{argdoc.SAMPLE_RATE}. Defaults to 16000") - parser.add_argument("-o", "--output-file", type=Path, help="Output RTTM file. Defaults to no writing") + parser.add_argument( + "--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5" + ) + parser.add_argument( + "-sr", + "--sample-rate", + default=16000, + type=int, + help=f"{argdoc.SAMPLE_RATE}. Defaults to 16000", + ) + parser.add_argument( + "-o", + "--output-file", + type=Path, + help="Output RTTM file. 
Defaults to no writing", + ) args = parser.parse_args() # Run websocket client ws = WebSocket() ws.connect(f"ws://{args.host}:{args.port}") - sender = Thread(target=send_audio, args=[ws, args.source, args.step, args.sample_rate]) + sender = Thread( + target=send_audio, args=[ws, args.source, args.step, args.sample_rate] + ) receiver = Thread(target=receive_audio, args=[ws, args.output_file]) sender.start() receiver.start() diff --git a/src/diart/console/serve.py b/src/diart/console/serve.py index 2f632d57..bc002e42 100644 --- a/src/diart/console/serve.py +++ b/src/diart/console/serve.py @@ -1,10 +1,13 @@ import argparse from pathlib import Path -import diart.argdoc as argdoc -import diart.sources as src -from diart.blocks import OnlineSpeakerDiarization, PipelineConfig -from diart.inference import RealTimeInference +import torch + +from diart import argdoc +from diart import models as m +from diart import sources as src +from diart import utils +from diart.inference import StreamingInference from diart.sinks import RTTMWriter @@ -12,34 +15,91 @@ def run(): parser = argparse.ArgumentParser() parser.add_argument("--host", default="0.0.0.0", type=str, help="Server host") parser.add_argument("--port", default=7007, type=int, help="Server port") - parser.add_argument("--segmentation", default="pyannote/segmentation", type=str, - help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation") - parser.add_argument("--embedding", default="pyannote/embedding", type=str, - help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding") - parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5") - parser.add_argument("--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5") - parser.add_argument("--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5") - parser.add_argument("--rho", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3") - parser.add_argument("--delta", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1") - parser.add_argument("--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. Defaults to 3") - parser.add_argument("--beta", default=10, type=float, help=f"{argdoc.BETA}. Defaults to 10") - parser.add_argument("--max-speakers", default=20, type=int, help=f"{argdoc.MAX_SPEAKERS}. Defaults to 20") - parser.add_argument("--cpu", dest="cpu", action="store_true", - help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise") - parser.add_argument("--output", type=Path, help=f"{argdoc.OUTPUT}. Defaults to no writing") - parser.add_argument("--hf-token", default="true", type=str, - help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)") + parser.add_argument( + "--pipeline", + default="SpeakerDiarization", + type=str, + help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'", + ) + parser.add_argument( + "--segmentation", + default="pyannote/segmentation", + type=str, + help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation", + ) + parser.add_argument( + "--embedding", + default="pyannote/embedding", + type=str, + help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding", + ) + parser.add_argument( + "--duration", + type=float, + help=f"{argdoc.DURATION}. Defaults to training segmentation duration", + ) + parser.add_argument( + "--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5" + ) + parser.add_argument( + "--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. 
Defaults to 0.5" + ) + parser.add_argument( + "--tau-active", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5" + ) + parser.add_argument( + "--rho-update", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3" + ) + parser.add_argument( + "--delta-new", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1" + ) + parser.add_argument( + "--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. Defaults to 3" + ) + parser.add_argument( + "--beta", default=10, type=float, help=f"{argdoc.BETA}. Defaults to 10" + ) + parser.add_argument( + "--max-speakers", + default=20, + type=int, + help=f"{argdoc.MAX_SPEAKERS}. Defaults to 20", + ) + parser.add_argument( + "--cpu", + dest="cpu", + action="store_true", + help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise", + ) + parser.add_argument( + "--output", type=Path, help=f"{argdoc.OUTPUT}. Defaults to no writing" + ) + parser.add_argument( + "--hf-token", + default="true", + type=str, + help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)", + ) args = parser.parse_args() - # Define online speaker diarization pipeline - config = PipelineConfig.from_dict(vars(args)) - pipeline = OnlineSpeakerDiarization(config) + # Resolve device + args.device = torch.device("cpu") if args.cpu else None + + # Resolve models + hf_token = utils.parse_hf_token_arg(args.hf_token) + args.segmentation = m.SegmentationModel.from_pyannote(args.segmentation, hf_token) + args.embedding = m.EmbeddingModel.from_pyannote(args.embedding, hf_token) + + # Resolve pipeline + pipeline_class = utils.get_pipeline_class(args.pipeline) + config = pipeline_class.get_config_class()(**vars(args)) + pipeline = pipeline_class(config) # Create websocket audio source audio_source = src.WebSocketAudioSource(config.sample_rate, args.host, args.port) # Run online inference - inference = RealTimeInference( + inference = StreamingInference( pipeline, audio_source, batch_size=1, @@ -50,7 +110,9 @@ def run(): # Write to disk if required if args.output is not None: - inference.attach_observers(RTTMWriter(audio_source.uri, args.output / f"{audio_source.uri}.rttm")) + inference.attach_observers( + RTTMWriter(audio_source.uri, args.output / f"{audio_source.uri}.rttm") + ) # Send back responses as RTTM text lines inference.attach_hooks(lambda ann_wav: audio_source.send(ann_wav[0].to_rttm())) diff --git a/src/diart/console/stream.py b/src/diart/console/stream.py index d7218f07..713f3e99 100644 --- a/src/diart/console/stream.py +++ b/src/diart/console/stream.py @@ -1,57 +1,130 @@ import argparse from pathlib import Path -import diart.argdoc as argdoc -import diart.sources as src -from diart.blocks import OnlineSpeakerDiarization, PipelineConfig -from diart.inference import RealTimeInference +import torch + +from diart import argdoc +from diart import models as m +from diart import sources as src +from diart import utils +from diart.inference import StreamingInference from diart.sinks import RTTMWriter def run(): parser = argparse.ArgumentParser() - parser.add_argument("source", type=str, help="Path to an audio file | 'microphone' | 'microphone:'") - parser.add_argument("--segmentation", default="pyannote/segmentation", type=str, - help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation") - parser.add_argument("--embedding", default="pyannote/embedding", type=str, - help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding") - parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. 
Defaults to 0.5") - parser.add_argument("--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5") - parser.add_argument("--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5") - parser.add_argument("--rho", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3") - parser.add_argument("--delta", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1") - parser.add_argument("--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. Defaults to 3") - parser.add_argument("--beta", default=10, type=float, help=f"{argdoc.BETA}. Defaults to 10") - parser.add_argument("--max-speakers", default=20, type=int, help=f"{argdoc.MAX_SPEAKERS}. Defaults to 20") - parser.add_argument("--no-plot", dest="no_plot", action="store_true", help="Skip plotting for faster inference") - parser.add_argument("--cpu", dest="cpu", action="store_true", - help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise") - parser.add_argument("--output", type=str, - help=f"{argdoc.OUTPUT}. Defaults to home directory if SOURCE == 'microphone' or parent directory if SOURCE is a file") - parser.add_argument("--hf-token", default="true", type=str, - help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)") + parser.add_argument( + "source", + type=str, + help="Path to an audio file | 'microphone' | 'microphone:'", + ) + parser.add_argument( + "--pipeline", + default="SpeakerDiarization", + type=str, + help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'", + ) + parser.add_argument( + "--segmentation", + default="pyannote/segmentation", + type=str, + help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation", + ) + parser.add_argument( + "--embedding", + default="pyannote/embedding", + type=str, + help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding", + ) + parser.add_argument( + "--duration", + type=float, + help=f"{argdoc.DURATION}. Defaults to training segmentation duration", + ) + parser.add_argument( + "--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5" + ) + parser.add_argument( + "--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5" + ) + parser.add_argument( + "--tau-active", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5" + ) + parser.add_argument( + "--rho-update", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3" + ) + parser.add_argument( + "--delta-new", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1" + ) + parser.add_argument( + "--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. Defaults to 3" + ) + parser.add_argument( + "--beta", default=10, type=float, help=f"{argdoc.BETA}. Defaults to 10" + ) + parser.add_argument( + "--max-speakers", + default=20, + type=int, + help=f"{argdoc.MAX_SPEAKERS}. Defaults to 20", + ) + parser.add_argument( + "--no-plot", + dest="no_plot", + action="store_true", + help="Skip plotting for faster inference", + ) + parser.add_argument( + "--cpu", + dest="cpu", + action="store_true", + help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise", + ) + parser.add_argument( + "--output", + type=str, + help=f"{argdoc.OUTPUT}. Defaults to home directory if SOURCE == 'microphone' or parent directory if SOURCE is a file", + ) + parser.add_argument( + "--hf-token", + default="true", + type=str, + help=f"{argdoc.HF_TOKEN}. 
Defaults to 'true' (required by pyannote)", + ) args = parser.parse_args() - # Define online speaker diarization pipeline - config = PipelineConfig.from_dict(vars(args)) - pipeline = OnlineSpeakerDiarization(config) + # Resolve device + args.device = torch.device("cpu") if args.cpu else None + + # Resolve models + hf_token = utils.parse_hf_token_arg(args.hf_token) + args.segmentation = m.SegmentationModel.from_pyannote(args.segmentation, hf_token) + args.embedding = m.EmbeddingModel.from_pyannote(args.embedding, hf_token) + + # Resolve pipeline + pipeline_class = utils.get_pipeline_class(args.pipeline) + config = pipeline_class.get_config_class()(**vars(args)) + pipeline = pipeline_class(config) # Manage audio source - block_size = config.optimal_block_size() source_components = args.source.split(":") if source_components[0] != "microphone": args.source = Path(args.source).expanduser() args.output = args.source.parent if args.output is None else Path(args.output) padding = config.get_file_padding(args.source) - audio_source = src.FileAudioSource(args.source, config.sample_rate, padding, block_size) + audio_source = src.FileAudioSource( + args.source, config.sample_rate, padding, config.step + ) pipeline.set_timestamp_shift(-padding[0]) else: - args.output = Path("~/").expanduser() if args.output is None else Path(args.output) + args.output = ( + Path("~/").expanduser() if args.output is None else Path(args.output) + ) device = int(source_components[1]) if len(source_components) > 1 else None - audio_source = src.MicrophoneAudioSource(config.sample_rate, block_size, device) + audio_source = src.MicrophoneAudioSource(config.step, device) # Run online inference - inference = RealTimeInference( + inference = StreamingInference( pipeline, audio_source, batch_size=1, @@ -59,8 +132,13 @@ def run(): do_plot=not args.no_plot, show_progress=True, ) - inference.attach_observers(RTTMWriter(audio_source.uri, args.output / f"{audio_source.uri}.rttm")) - inference() + inference.attach_observers( + RTTMWriter(audio_source.uri, args.output / f"{audio_source.uri}.rttm") + ) + try: + inference() + except KeyboardInterrupt: + pass if __name__ == "__main__": diff --git a/src/diart/console/tune.py b/src/diart/console/tune.py index 4ad8852a..ec243348 100644 --- a/src/diart/console/tune.py +++ b/src/diart/console/tune.py @@ -1,54 +1,145 @@ import argparse from pathlib import Path -import diart.argdoc as argdoc import optuna -from diart.blocks import PipelineConfig, OnlineSpeakerDiarization -from diart.optim import Optimizer, HyperParameter +import torch from optuna.samplers import TPESampler +from diart import argdoc +from diart import models as m +from diart import utils +from diart.blocks.base import HyperParameter +from diart.optim import Optimizer + def run(): parser = argparse.ArgumentParser() - parser.add_argument("root", type=str, help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)") - parser.add_argument("--reference", required=True, type=str, - help="Directory with RTTM files CONVERSATION.rttm. Names must match audio files") - parser.add_argument("--segmentation", default="pyannote/segmentation", type=str, - help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation") - parser.add_argument("--embedding", default="pyannote/embedding", type=str, - help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding") - parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5") - parser.add_argument("--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. 
Defaults to 0.5") - parser.add_argument("--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5") - parser.add_argument("--rho", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3") - parser.add_argument("--delta", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1") - parser.add_argument("--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. Defaults to 3") - parser.add_argument("--beta", default=10, type=float, help=f"{argdoc.BETA}. Defaults to 10") - parser.add_argument("--max-speakers", default=20, type=int, help=f"{argdoc.MAX_SPEAKERS}. Defaults to 20") - parser.add_argument("--batch-size", default=32, type=int, help=f"{argdoc.BATCH_SIZE}. Defaults to 32") - parser.add_argument("--cpu", dest="cpu", action="store_true", - help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise") - parser.add_argument("--hparams", nargs="+", default=("tau_active", "rho_update", "delta_new"), - help="Hyper-parameters to optimize. Must match names in `PipelineConfig`. Defaults to tau_active, rho_update and delta_new") - parser.add_argument("--num-iter", default=100, type=int, help="Number of optimization trials") - parser.add_argument("--storage", type=str, - help="Optuna storage string. If provided, continue a previous study instead of creating one. The database name must match the study name") + parser.add_argument( + "root", + type=str, + help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)", + ) + parser.add_argument( + "--reference", + required=True, + type=str, + help="Directory with RTTM files CONVERSATION.rttm. Names must match audio files", + ) + parser.add_argument( + "--pipeline", + default="SpeakerDiarization", + type=str, + help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'", + ) + parser.add_argument( + "--segmentation", + default="pyannote/segmentation", + type=str, + help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation", + ) + parser.add_argument( + "--embedding", + default="pyannote/embedding", + type=str, + help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding", + ) + parser.add_argument( + "--duration", + type=float, + help=f"{argdoc.DURATION}. Defaults to training segmentation duration", + ) + parser.add_argument( + "--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5" + ) + parser.add_argument( + "--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5" + ) + parser.add_argument( + "--tau-active", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5" + ) + parser.add_argument( + "--rho-update", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3" + ) + parser.add_argument( + "--delta-new", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1" + ) + parser.add_argument( + "--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. Defaults to 3" + ) + parser.add_argument( + "--beta", default=10, type=float, help=f"{argdoc.BETA}. Defaults to 10" + ) + parser.add_argument( + "--max-speakers", + default=20, + type=int, + help=f"{argdoc.MAX_SPEAKERS}. Defaults to 20", + ) + parser.add_argument( + "--batch-size", + default=32, + type=int, + help=f"{argdoc.BATCH_SIZE}. Defaults to 32", + ) + parser.add_argument( + "--cpu", + dest="cpu", + action="store_true", + help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise", + ) + parser.add_argument( + "--hparams", + nargs="+", + default=("tau_active", "rho_update", "delta_new"), + help="Hyper-parameters to optimize. Must match names in `PipelineConfig`. 
Defaults to tau_active, rho_update and delta_new", + ) + parser.add_argument( + "--num-iter", default=100, type=int, help="Number of optimization trials" + ) + parser.add_argument( + "--storage", + type=str, + help="Optuna storage string. If provided, continue a previous study instead of creating one. The database name must match the study name", + ) parser.add_argument("--output", type=str, help="Working directory") - parser.add_argument("--hf-token", default="true", type=str, - help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)") + parser.add_argument( + "--hf-token", + default="true", + type=str, + help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)", + ) args = parser.parse_args() + # Resolve device + args.device = torch.device("cpu") if args.cpu else None + + # Resolve models + hf_token = utils.parse_hf_token_arg(args.hf_token) + args.segmentation = m.SegmentationModel.from_pyannote(args.segmentation, hf_token) + args.embedding = m.EmbeddingModel.from_pyannote(args.embedding, hf_token) + + # Retrieve pipeline class + pipeline_class = utils.get_pipeline_class(args.pipeline) + # Create the base configuration for each trial - base_config = PipelineConfig.from_dict(vars(args)) + base_config = pipeline_class.get_config_class()(**vars(args)) # Create hyper-parameters to optimize + possible_hparams = pipeline_class.hyper_parameters() hparams = [HyperParameter.from_name(name) for name in args.hparams] + hparams = [hp for hp in hparams if hp in possible_hparams] + if not hparams: + print( + f"No hyper-parameters to optimize. " + f"Make sure to select one of: {', '.join([hp.name for hp in possible_hparams])}" + ) + exit(1) # Use a custom storage if given if args.output is not None: msg = "Both `output` and `storage` were set, but only one was expected" assert args.storage is None, msg - args.output = Path(args.output) + args.output = Path(args.output).expanduser() args.output.mkdir(parents=True, exist_ok=True) study_or_path = args.output elif args.storage is not None: @@ -60,11 +151,11 @@ def run(): # Run optimization Optimizer( + pipeline_class=pipeline_class, speech_path=args.root, reference_path=args.reference, study_or_path=study_or_path, batch_size=args.batch_size, - pipeline_class=OnlineSpeakerDiarization, hparams=hparams, base_config=base_config, )(num_iter=args.num_iter, show_progress=True) diff --git a/src/diart/features.py b/src/diart/features.py index 2489027a..2d5df672 100644 --- a/src/diart/features.py +++ b/src/diart/features.py @@ -1,4 +1,5 @@ from typing import Union, Optional +from abc import ABC, abstractmethod import numpy as np import torch @@ -7,15 +8,18 @@ TemporalFeatures = Union[SlidingWindowFeature, np.ndarray, torch.Tensor] -class TemporalFeatureFormatterState: +class TemporalFeatureFormatterState(ABC): """ Represents the recorded type of a temporal feature formatter. Its job is to transform temporal features into tensors and recover the original format on other features. """ + + @abstractmethod def to_tensor(self, features: TemporalFeatures) -> torch.Tensor: - raise NotImplementedError + pass + @abstractmethod def to_internal_type(self, features: torch.Tensor) -> TemporalFeatures: """ Cast `features` to the representing type and remove batch dimension if required. 
@@ -28,7 +32,7 @@ def to_internal_type(self, features: torch.Tensor) -> TemporalFeatures: ------- new_features: SlidingWindowFeature or numpy.ndarray or torch.Tensor, shape (batch, frames, dim) """ - raise NotImplementedError + pass class SlidingWindowFeatureFormatterState(TemporalFeatureFormatterState): @@ -48,7 +52,9 @@ def to_internal_type(self, features: torch.Tensor) -> TemporalFeatures: # Calculate resolution resolution = self.duration / num_frames # Temporal shift to keep track of current start time - resolution = SlidingWindow(start=self._cur_start_time, duration=resolution, step=resolution) + resolution = SlidingWindow( + start=self._cur_start_time, duration=resolution, step=resolution + ) return SlidingWindowFeature(features.squeeze(dim=0).cpu().numpy(), resolution) @@ -74,6 +80,7 @@ class TemporalFeatureFormatter: When casting temporal features as torch.Tensor, it remembers its type and format so it can lately restore it on other temporal features. """ + def __init__(self): self.state: Optional[TemporalFeatureFormatterState] = None diff --git a/src/diart/inference.py b/src/diart/inference.py index f4b65f5f..3eb72930 100644 --- a/src/diart/inference.py +++ b/src/diart/inference.py @@ -4,32 +4,33 @@ from traceback import print_exc from typing import Union, Text, Optional, Callable, Tuple, List -import diart.operators as dops -import diart.sources as src import numpy as np import pandas as pd import rx import rx.operators as ops import torch -from diart import utils -from diart.blocks import BasePipeline, Resample, BasePipelineConfig -from diart.progress import ProgressBar, RichProgressBar, TQDMProgressBar -from diart.sinks import DiarizationPredictionAccumulator, RealTimePlot, WindowClosedException from pyannote.core import Annotation, SlidingWindowFeature from pyannote.database.util import load_rttm -from pyannote.metrics.diarization import DiarizationErrorRate +from pyannote.metrics.base import BaseMetric from rx.core import Observer from tqdm import tqdm +from . import blocks +from . import operators as dops +from . import sources as src +from . import utils +from .progress import ProgressBar, RichProgressBar, TQDMProgressBar +from .sinks import PredictionAccumulator, StreamingPlot, WindowClosedException -class RealTimeInference: + +class StreamingInference: """Performs inference in real time given a pipeline and an audio source. Streams an audio source to an online speaker diarization pipeline. It allows users to attach a chain of operations in the form of hooks. Parameters ---------- - pipeline: BasePipeline + pipeline: StreamingPipeline Configured speaker diarization pipeline. source: AudioSource Audio source to be read and streamed. @@ -50,9 +51,10 @@ class RealTimeInference: If description is not provided, set to 'Streaming '. Defaults to RichProgressBar(). 
""" + def __init__( self, - pipeline: BasePipeline, + pipeline: blocks.Pipeline, source: src.AudioSource, batch_size: int = 1, do_profile: bool = True, @@ -66,7 +68,7 @@ def __init__( self.do_profile = do_profile self.do_plot = do_plot self.show_progress = show_progress - self.accumulator = DiarizationPredictionAccumulator(self.source.uri) + self.accumulator = PredictionAccumulator(self.source.uri) self.unit = "chunk" if self.batch_size == 1 else "batch" self._observers = [] @@ -88,7 +90,7 @@ def __init__( self._pbar.create( total=self.num_chunks, description=f"Streaming {self.source.uri}", - unit=self.unit + unit=self.unit, ) # Initialize chronometer for profiling @@ -96,18 +98,32 @@ def __init__( self.stream = self.source.stream + # Rearrange stream to form sliding windows + self.stream = self.stream.pipe( + dops.rearrange_audio_stream( + chunk_duration, step_duration, source.sample_rate + ), + ) + # Dynamic resampling if the audio source isn't compatible if sample_rate != self.source.sample_rate: - msg = f"Audio source has sample rate {self.source.sample_rate}, " \ - f"but pipeline's is {sample_rate}. Will resample." + msg = ( + f"Audio source has sample rate {self.source.sample_rate}, " + f"but pipeline's is {sample_rate}. Will resample." + ) logging.warning(msg) self.stream = self.stream.pipe( - ops.map(Resample(self.source.sample_rate, sample_rate)) + ops.map( + blocks.Resample( + self.source.sample_rate, + sample_rate, + self.pipeline.config.device, + ) + ) ) - # Add rx operators to manage the inputs and outputs of the pipeline + # Form batches self.stream = self.stream.pipe( - dops.rearrange_audio_stream(chunk_duration, step_duration, sample_rate), ops.buffer_with_count(count=self.batch_size), ) @@ -140,7 +156,9 @@ def _close_chronometer(self): self._chrono.stop(do_count=False) self._chrono.report() - def attach_hooks(self, *hooks: Callable[[Tuple[Annotation, SlidingWindowFeature]], None]): + def attach_hooks( + self, *hooks: Callable[[Tuple[Annotation, SlidingWindowFeature]], None] + ): """Attach hooks to the pipeline. Parameters @@ -202,7 +220,7 @@ def __call__(self) -> Annotation: latency=config.latency, sample_rate=config.sample_rate, ), - ops.do(RealTimePlot(config.duration, config.latency)), + ops.do(StreamingPlot(config.duration, config.latency)), ) observable.subscribe( on_error=self._handle_error, @@ -246,6 +264,7 @@ class Benchmark: The performance between this two modes does not differ. Defaults to 32. """ + def __init__( self, speech_path: Union[Text, Path], @@ -288,7 +307,7 @@ def get_file_paths(self) -> List[Path]: def run_single( self, - pipeline: BasePipeline, + pipeline: blocks.Pipeline, filepath: Path, progress_bar: ProgressBar, ) -> Annotation: @@ -298,7 +317,7 @@ def run_single( Parameters ---------- - pipeline: BasePipeline + pipeline: StreamingPipeline Speaker diarization pipeline to run. filepath: Path Path to the target file. 
@@ -315,10 +334,10 @@ def run_single( filepath, pipeline.config.sample_rate, padding, - pipeline.config.optimal_block_size(), + pipeline.config.step, ) pipeline.set_timestamp_shift(-padding[0]) - inference = RealTimeInference( + inference = StreamingInference( pipeline, source, self.batch_size, @@ -337,7 +356,11 @@ def run_single( return pred - def evaluate(self, predictions: List[Annotation]) -> Union[pd.DataFrame, List[Annotation]]: + def evaluate( + self, + predictions: List[Annotation], + metric: BaseMetric, + ) -> Union[pd.DataFrame, List[Annotation]]: """If a reference path was provided, compute the diarization error rate of a list of predictions. @@ -345,6 +368,8 @@ def evaluate(self, predictions: List[Annotation]) -> Union[pd.DataFrame, List[An ---------- predictions: List[Annotation] Predictions to evaluate. + metric: BaseMetric + Evaluation metric from pyannote.metrics. Returns ------- @@ -353,8 +378,7 @@ def evaluate(self, predictions: List[Annotation]) -> Union[pd.DataFrame, List[An reference path was given. Otherwise return the same predictions. """ if self.reference_path is not None: - metric = DiarizationErrorRate(collar=0, skip_overlap=False) - progress_bar = TQDMProgressBar("Computing DER", leave=False) + progress_bar = TQDMProgressBar(f"Computing {metric.name}", leave=False) progress_bar.create(total=len(predictions), unit="file") progress_bar.start() for hyp in predictions: @@ -368,18 +392,22 @@ def evaluate(self, predictions: List[Annotation]) -> Union[pd.DataFrame, List[An def __call__( self, pipeline_class: type, - config: BasePipelineConfig, + config: blocks.PipelineConfig, + metric: Optional[BaseMetric] = None, ) -> Union[pd.DataFrame, List[Annotation]]: """Run a given pipeline on a set of audio files. - Notice that the internal state of the pipeline is reset before benchmarking. + The internal state of the pipeline is reset before benchmarking. Parameters ---------- pipeline_class: class - Class from the BasePipeline hierarchy. + Class from the StreamingPipeline hierarchy. A pipeline from this class will be instantiated by each worker. - config: BasePipelineConfig - Diarization pipeline configuration. + config: StreamingConfig + Streaming pipeline configuration. + metric: Optional[BaseMetric] + Evaluation metric from pyannote.metrics. + Defaults to the pipeline's suggested metric (see `StreamingPipeline.suggest_metric()`) Returns ------- @@ -400,7 +428,8 @@ def __call__( progress = TQDMProgressBar(desc, leave=False, do_close=True) predictions.append(self.run_single(pipeline, filepath, progress)) - return self.evaluate(predictions) + metric = pipeline.suggest_metric() if metric is None else metric + return self.evaluate(predictions, metric) class Parallelize: @@ -415,6 +444,7 @@ class Parallelize: Number of parallel workers. Defaults to 0 (no parallelism). """ + def __init__( self, benchmark: Benchmark, @@ -426,20 +456,20 @@ def __init__( def run_single_job( self, pipeline_class: type, - config: BasePipelineConfig, + config: blocks.PipelineConfig, filepath: Path, description: Text, - ): + ) -> Annotation: """Build and run a pipeline on a single file. Configure execution to show progress alongside parallel runs. Parameters ---------- pipeline_class: class - Class from the BasePipeline hierarchy. + Class from the StreamingPipeline hierarchy. A pipeline from this class will be instantiated. - config: BasePipelineConfig - Diarization pipeline configuration. + config: StreamingConfig + Streaming pipeline configuration. filepath: Path Path to the target file. 
description: Text @@ -451,19 +481,22 @@ def run_single_job( Pipeline prediction for the given file. """ # The process ID inside the pool determines the position of the progress bar - idx_process = int(current_process().name.split('-')[1]) - 1 + idx_process = int(current_process().name.split("-")[1]) - 1 # TODO share models across processes # Instantiate a pipeline with the config pipeline = pipeline_class(config) # Create the progress bar for this job - progress = TQDMProgressBar(description, leave=False, position=idx_process, do_close=True) + progress = TQDMProgressBar( + description, leave=False, position=idx_process, do_close=True + ) # Run the pipeline return self.benchmark.run_single(pipeline, filepath, progress) def __call__( self, pipeline_class: type, - config: BasePipelineConfig, + config: blocks.PipelineConfig, + metric: Optional[BaseMetric] = None, ) -> Union[pd.DataFrame, List[Annotation]]: """Run a given pipeline on a set of audio files in parallel. Each worker will build and run the pipeline on a different file. @@ -471,10 +504,13 @@ def __call__( Parameters ---------- pipeline_class: class - Class from the BasePipeline hierarchy. + Class from the StreamingPipeline hierarchy. A pipeline from this class will be instantiated by each worker. - config: BasePipelineConfig - Diarization pipeline configuration. + config: StreamingConfig + Streaming pipeline configuration. + metric: Optional[BaseMetric] + Evaluation metric from pyannote.metrics. + Defaults to the pipeline's suggested metric (see `StreamingPipeline.suggest_metric()`) Returns ------- @@ -488,12 +524,14 @@ def __call__( num_audio_files = len(audio_file_paths) # Workaround for multiprocessing with GPU - torch.multiprocessing.set_start_method('spawn') + torch.multiprocessing.set_start_method("spawn") # For Windows support freeze_support() # Create the pool of workers using a lock for parallel tqdm usage - pool = Pool(processes=self.num_workers, initargs=(RLock(),), initializer=tqdm.set_lock) + pool = Pool( + processes=self.num_workers, initargs=(RLock(),), initializer=tqdm.set_lock + ) # Determine the arguments for each job arg_list = [ ( @@ -512,4 +550,5 @@ def __call__( predictions = [job.get() for job in jobs] # Evaluate results - return self.benchmark.evaluate(predictions) + metric = pipeline_class.suggest_metric() if metric is None else metric + return self.benchmark.evaluate(predictions, metric) diff --git a/src/diart/mapping.py b/src/diart/mapping.py index 3023da4d..847d0f78 100644 --- a/src/diart/mapping.py +++ b/src/diart/mapping.py @@ -1,13 +1,14 @@ from __future__ import annotations from typing import Callable, Iterable, List, Optional, Text, Tuple, Union, Dict +from abc import ABC, abstractmethod import numpy as np from pyannote.core.utils.distance import cdist from scipy.optimize import linear_sum_assignment as lsap -class MappingMatrixObjective: +class MappingMatrixObjective(ABC): def invalid_tensor(self, shape: Union[Tuple, int]) -> np.ndarray: return np.ones(shape) * self.invalid_value @@ -51,16 +52,19 @@ def invalid_value(self) -> float: return -1e10 if self.maximize else 1e10 @property + @abstractmethod def maximize(self) -> bool: - raise NotImplementedError() + pass @property + @abstractmethod def best_possible_value(self) -> float: - raise NotImplementedError() + pass @property + @abstractmethod def best_value_fn(self) -> Callable: - raise NotImplementedError() + pass class MinimizationObjective(MappingMatrixObjective): diff --git a/src/diart/models.py b/src/diart/models.py index df66e166..5577a097 
100644 --- a/src/diart/models.py +++ b/src/diart/models.py @@ -1,3 +1,4 @@ +from abc import ABC, abstractmethod from typing import Optional, Text, Union, Callable import torch @@ -5,6 +6,7 @@ try: import pyannote.audio.pipelines.utils as pyannote_loader + _has_pyannote = True except ImportError: _has_pyannote = False @@ -20,7 +22,7 @@ def __call__(self) -> nn.Module: return pyannote_loader.get_model(self.model_info, self.hf_token) -class LazyModel(nn.Module): +class LazyModel(nn.Module, ABC): def __init__(self, loader: Callable[[], nn.Module]): super().__init__() self.get_model = loader @@ -47,8 +49,11 @@ class SegmentationModel(LazyModel): """ Minimal interface for a segmentation model. """ + @staticmethod - def from_pyannote(model, use_hf_token: Union[Text, bool, None] = True) -> 'SegmentationModel': + def from_pyannote( + model, use_hf_token: Union[Text, bool, None] = True + ) -> "SegmentationModel": """ Returns a `SegmentationModel` wrapping a pyannote model. @@ -69,13 +74,16 @@ def from_pyannote(model, use_hf_token: Union[Text, bool, None] = True) -> 'Segme return PyannoteSegmentationModel(model, use_hf_token) @property + @abstractmethod def sample_rate(self) -> int: - raise NotImplementedError + pass @property + @abstractmethod def duration(self) -> float: - raise NotImplementedError + pass + @abstractmethod def forward(self, waveform: torch.Tensor) -> torch.Tensor: """ Forward pass of the segmentation model. @@ -88,7 +96,7 @@ def forward(self, waveform: torch.Tensor) -> torch.Tensor: ------- speaker_segmentation: torch.Tensor, shape (batch, frames, speakers) """ - raise NotImplementedError + pass class PyannoteSegmentationModel(SegmentationModel): @@ -111,8 +119,11 @@ def forward(self, waveform: torch.Tensor) -> torch.Tensor: class EmbeddingModel(LazyModel): """Minimal interface for an embedding model.""" + @staticmethod - def from_pyannote(model, use_hf_token: Union[Text, bool, None] = True) -> 'EmbeddingModel': + def from_pyannote( + model, use_hf_token: Union[Text, bool, None] = True + ) -> "EmbeddingModel": """ Returns an `EmbeddingModel` wrapping a pyannote model. @@ -132,10 +143,9 @@ def from_pyannote(model, use_hf_token: Union[Text, bool, None] = True) -> 'Embed assert _has_pyannote, "No pyannote.audio installation found" return PyannoteEmbeddingModel(model, use_hf_token) + @abstractmethod def forward( - self, - waveform: torch.Tensor, - weights: Optional[torch.Tensor] = None + self, waveform: torch.Tensor, weights: Optional[torch.Tensor] = None ) -> torch.Tensor: """ Forward pass of an embedding model with optional weights. @@ -150,7 +160,7 @@ def forward( ------- speaker_embeddings: torch.Tensor, shape (batch, embedding_dim) """ - raise NotImplementedError + pass class PyannoteEmbeddingModel(EmbeddingModel): diff --git a/src/diart/operators.py b/src/diart/operators.py index 6d73fc9d..7ce13285 100644 --- a/src/diart/operators.py +++ b/src/diart/operators.py @@ -25,20 +25,24 @@ def initial(): def has_samples(num_samples: int): def call_fn(state) -> bool: return state.chunk is not None and state.chunk.shape[1] == num_samples + return call_fn @staticmethod def to_sliding_window(sample_rate: int): def call_fn(state) -> SlidingWindowFeature: - resolution = SlidingWindow(start=state.start_time, duration=1. / sample_rate, step=1. 
/ sample_rate) + resolution = SlidingWindow( + start=state.start_time, + duration=1.0 / sample_rate, + step=1.0 / sample_rate, + ) return SlidingWindowFeature(state.chunk.T, resolution) + return call_fn def rearrange_audio_stream( - duration: float = 5, - step: float = 0.5, - sample_rate: int = 16000 + duration: float = 5, step: float = 0.5, sample_rate: int = 16000 ) -> Operator: chunk_samples = int(round(sample_rate * duration)) step_samples = int(round(sample_rate * step)) @@ -49,11 +53,17 @@ def rearrange_audio_stream( def accumulate(state: AudioBufferState, value: np.ndarray): # State contains the last emitted chunk, the current step buffer and the last starting time if value.ndim != 2 or value.shape[0] != 1: - raise ValueError(f"Waveform must have shape (1, samples) but {value.shape} was found") + raise ValueError( + f"Waveform must have shape (1, samples) but {value.shape} was found" + ) start_time = state.start_time # Add new samples to the buffer - buffer = value if state.buffer is None else np.concatenate([state.buffer, value], axis=1) + buffer = ( + value + if state.buffer is None + else np.concatenate([state.buffer, value], axis=1) + ) # Check for buffer overflow if buffer.shape[1] >= step_samples: @@ -86,7 +96,7 @@ def accumulate(state: AudioBufferState, value: np.ndarray): ops.filter(AudioBufferState.has_samples(chunk_samples)), ops.filter(lambda state: state.changed), # Transform state into a SlidingWindowFeature containing the new chunk - ops.map(AudioBufferState.to_sliding_window(sample_rate)) + ops.map(AudioBufferState.to_sliding_window(sample_rate)), ) @@ -96,6 +106,7 @@ def accumulate(state: List[Any], value: Any) -> List[Any]: if len(new_state) > n: return new_state[1:] return new_state + return rx.pipe(ops.scan(accumulate, [])) @@ -117,17 +128,19 @@ class OutputAccumulationState: next_sample: Optional[int] @staticmethod - def initial() -> 'OutputAccumulationState': + def initial() -> "OutputAccumulationState": return OutputAccumulationState(None, None, 0, 0) @property def cropped_waveform(self) -> SlidingWindowFeature: return SlidingWindowFeature( - self.waveform[:self.next_sample], + self.waveform[: self.next_sample], self.waveform.sliding_window, ) - def to_tuple(self) -> Tuple[Optional[Annotation], Optional[SlidingWindowFeature], float]: + def to_tuple( + self, + ) -> Tuple[Optional[Annotation], Optional[SlidingWindowFeature], float]: return self.annotation, self.cropped_waveform, self.real_time @@ -153,9 +166,10 @@ def accumulate_output( ------- A reactive x operator implementing this behavior. 
""" + def accumulate( state: OutputAccumulationState, - value: Tuple[Annotation, Optional[SlidingWindowFeature]] + value: Tuple[Annotation, Optional[SlidingWindowFeature]], ) -> OutputAccumulationState: value = PredictionWithAudio(*value) annotation, waveform = None, None @@ -187,7 +201,7 @@ def accumulate( (state.waveform.data, np.zeros_like(state.waveform.data)), axis=0 ) # Copy chunk into buffer - waveform[state.next_sample:new_next_sample] = value.waveform.data + waveform[state.next_sample : new_next_sample] = value.waveform.data waveform = SlidingWindowFeature(waveform, sw_holder.waveform.sliding_window) return OutputAccumulationState(annotation, waveform, real_time, new_next_sample) @@ -234,14 +248,14 @@ def buffer_output( def accumulate( state: OutputAccumulationState, - value: Tuple[Annotation, Optional[SlidingWindowFeature]] + value: Tuple[Annotation, Optional[SlidingWindowFeature]], ) -> OutputAccumulationState: value = PredictionWithAudio(*value) annotation, waveform = None, None # Determine the real time of the stream and the start time of the buffer real_time = duration if state.annotation is None else state.real_time + step - start_time = max(0., real_time - latency - duration) + start_time = max(0.0, real_time - latency - duration) # Update annotation and constrain its bounds to the buffer if state.annotation is None: @@ -267,7 +281,7 @@ def accumulate( elif state.next_sample <= num_samples: # The buffer isn't full, copy into next free buffer chunk waveform = state.waveform.data - waveform[state.next_sample:new_next_sample] = value.waveform.data + waveform[state.next_sample : new_next_sample] = value.waveform.data else: # The buffer is full, shift values to the left and copy into last buffer chunk waveform = np.roll(state.waveform.data, -num_step_samples, axis=0) @@ -277,7 +291,9 @@ def accumulate( waveform[-num_step_samples:] = value.waveform.data[:num_step_samples] # Wrap waveform in a sliding window feature to include timestamps - window = SlidingWindow(start=start_time, duration=resolution, step=resolution) + window = SlidingWindow( + start=start_time, duration=resolution, step=resolution + ) waveform = SlidingWindowFeature(waveform, window) return OutputAccumulationState(annotation, waveform, real_time, new_next_sample) diff --git a/src/diart/optim.py b/src/diart/optim.py index 05800a05..ca61d744 100644 --- a/src/diart/optim.py +++ b/src/diart/optim.py @@ -1,51 +1,32 @@ from collections import OrderedDict -from dataclasses import dataclass from pathlib import Path from typing import Sequence, Text, Optional, Union from optuna import TrialPruned, Study, create_study from optuna.samplers import TPESampler from optuna.trial import Trial, FrozenTrial +from pyannote.metrics.base import BaseMetric from tqdm import trange, tqdm +from typing_extensions import Literal +from . 
import blocks from .audio import FilePath -from .blocks import BasePipelineConfig, PipelineConfig, OnlineSpeakerDiarization from .inference import Benchmark -@dataclass -class HyperParameter: - name: Text - low: float - high: float - - @staticmethod - def from_name(name: Text) -> 'HyperParameter': - if name == "tau_active": - return TauActive - if name == "rho_update": - return RhoUpdate - if name == "delta_new": - return DeltaNew - raise ValueError(f"Hyper-parameter '{name}' not recognized") - - -TauActive = HyperParameter("tau_active", low=0, high=1) -RhoUpdate = HyperParameter("rho_update", low=0, high=1) -DeltaNew = HyperParameter("delta_new", low=0, high=2) - - class Optimizer: def __init__( self, + pipeline_class: type, speech_path: Union[Text, Path], reference_path: Union[Text, Path], study_or_path: Union[FilePath, Study], batch_size: int = 32, - pipeline_class: type = OnlineSpeakerDiarization, - hparams: Optional[Sequence[HyperParameter]] = None, - base_config: Optional[BasePipelineConfig] = None, + hparams: Optional[Sequence[blocks.base.HyperParameter]] = None, + base_config: Optional[blocks.PipelineConfig] = None, do_kickstart_hparams: bool = True, + metric: Optional[BaseMetric] = None, + direction: Literal["minimize", "maximize"] = "minimize", ): self.pipeline_class = pipeline_class # FIXME can we run this benchmark in parallel? @@ -58,21 +39,25 @@ def __init__( batch_size=batch_size, ) + self.metric = metric + self.direction = direction self.base_config = base_config self.do_kickstart_hparams = do_kickstart_hparams if self.base_config is None: - self.base_config = PipelineConfig() + self.base_config = self.pipeline_class.get_config_class()() self.do_kickstart_hparams = False self.hparams = hparams if self.hparams is None: - self.hparams = [TauActive, RhoUpdate, DeltaNew] + self.hparams = self.pipeline_class.hyper_parameters() # Make sure hyper-parameters exist in the configuration class given possible_hparams = vars(self.base_config) for param in self.hparams: - msg = f"Hyper-parameter {param.name} not found " \ - f"in configuration {self.base_config.__class__.__name__}" + msg = ( + f"Hyper-parameter {param.name} not found " + f"in configuration {self.base_config.__class__.__name__}" + ) assert param.name in possible_hparams, msg self._progress: Optional[tqdm] = None @@ -85,7 +70,7 @@ def __init__( storage="sqlite:///" + str(study_or_path / f"{study_or_path.stem}.db"), sampler=TPESampler(), study_name=study_or_path.stem, - direction="minimize", + direction=self.direction, load_if_exists=True, ) else: @@ -105,7 +90,7 @@ def _callback(self, study: Study, trial: FrozenTrial): return self._progress.update(1) self._progress.set_description(f"Trial {trial.number + 1}") - values = {"best_der": study.best_value} + values = {"best_perf": study.best_value} for name, value in study.best_params.items(): values[f"best_{name}"] = value self._progress.set_postfix(OrderedDict(values)) @@ -125,11 +110,16 @@ def objective(self, trial: Trial) -> float: # Instantiate the new configuration for the trial config = self.base_config.__class__(**trial_config) + # Determine the evaluation metric + metric = self.metric + if metric is None: + metric = self.pipeline_class.suggest_metric() + # Run pipeline over the dataset - report = self.benchmark(self.pipeline_class, config) + report = self.benchmark(self.pipeline_class, config, metric) - # Extract DER from report - return report.loc["TOTAL", "diarization error rate"]["%"] + # Extract target metric from report + return report.loc["TOTAL", metric.name]["%"] 
def __call__(self, num_iter: int, show_progress: bool = True): self._progress = None @@ -141,8 +131,11 @@ def __call__(self, num_iter: int, show_progress: bool = True): self._progress.set_description(f"Trial {last_trial + 1}") # Start with base config hyper-parameters if config was given if self.do_kickstart_hparams: - self.study.enqueue_trial({ - param.name: getattr(self.base_config, param.name) - for param in self.hparams - }, skip_if_exists=True) + self.study.enqueue_trial( + { + param.name: getattr(self.base_config, param.name) + for param in self.hparams + }, + skip_if_exists=True, + ) self.study.optimize(self.objective, num_iter, callbacks=[self._callback]) diff --git a/src/diart/progress.py b/src/diart/progress.py index 240cc628..d29080f9 100644 --- a/src/diart/progress.py +++ b/src/diart/progress.py @@ -1,3 +1,4 @@ +from abc import ABC, abstractmethod from typing import Optional, Text import rich @@ -5,32 +6,46 @@ from tqdm import tqdm -class ProgressBar: - def create(self, total: int, description: Optional[Text] = None, unit: Text = "it", **kwargs): - raise NotImplementedError +class ProgressBar(ABC): + @abstractmethod + def create( + self, + total: int, + description: Optional[Text] = None, + unit: Text = "it", + **kwargs, + ): + pass + @abstractmethod def start(self): - raise NotImplementedError + pass + @abstractmethod def update(self, n: int = 1): - raise NotImplementedError + pass + @abstractmethod def write(self, text: Text): - raise NotImplementedError + pass + @abstractmethod def stop(self): - raise NotImplementedError + pass + @abstractmethod def close(self): - raise NotImplementedError + pass @property + @abstractmethod def default_description(self) -> Text: - raise NotImplementedError + pass @property + @abstractmethod def initial_description(self) -> Optional[Text]: - raise NotImplementedError + pass def resolve_description(self, new_description: Optional[Text] = None) -> Text: if self.initial_description is None: @@ -66,7 +81,13 @@ def initial_description(self) -> Optional[Text]: return f"[{self.color}]{self.description}" return self.description - def create(self, total: int, description: Optional[Text] = None, unit: Text = "it", **kwargs): + def create( + self, + total: int, + description: Optional[Text] = None, + unit: Text = "it", + **kwargs, + ): if self.task_id is None: self.task_id = self.bar.add_task( self.resolve_description(f"[{self.color}]{description}"), @@ -74,7 +95,7 @@ def create(self, total: int, description: Optional[Text] = None, unit: Text = "i total=total, completed=0, visible=True, - **kwargs + **kwargs, ) def start(self): @@ -103,7 +124,7 @@ def __init__( description: Optional[Text] = None, leave: bool = True, position: Optional[int] = None, - do_close: bool = True + do_close: bool = True, ): self.description = description self.leave = leave @@ -119,7 +140,13 @@ def default_description(self) -> Text: def initial_description(self) -> Optional[Text]: return self.description - def create(self, total: int, description: Optional[Text] = None, unit: Optional[Text] = "it", **kwargs): + def create( + self, + total: int, + description: Optional[Text] = None, + unit: Optional[Text] = "it", + **kwargs, + ): if self.pbar is None: self.pbar = tqdm( desc=self.resolve_description(description), @@ -127,7 +154,7 @@ def create(self, total: int, description: Optional[Text] = None, unit: Optional[ unit=unit, leave=self.leave, position=self.position, - **kwargs + **kwargs, ) def start(self): diff --git a/src/diart/sinks.py b/src/diart/sinks.py index cf480bed..ed4e2ea0 
100644 --- a/src/diart/sinks.py +++ b/src/diart/sinks.py @@ -8,12 +8,14 @@ from rx.core import Observer from typing_extensions import Literal +from . import utils + class WindowClosedException(Exception): pass -def _extract_annotation(value: Union[Tuple, Annotation]) -> Annotation: +def _extract_prediction(value: Union[Tuple, Annotation]) -> Annotation: if isinstance(value, tuple): return value[0] if isinstance(value, Annotation): @@ -39,14 +41,15 @@ def patch(self): if annotations: annotation = annotations[0] annotation.uri = self.uri - with open(self.path, 'w') as file: + with open(self.path, "w") as file: annotation.support(self.patch_collar).write_rttm(file) def on_next(self, value: Union[Tuple, Annotation]): - annotation = _extract_annotation(value) - annotation.uri = self.uri - with open(self.path, 'a') as file: - annotation.write_rttm(file) + prediction = _extract_prediction(value) + # Write prediction in RTTM format + prediction.uri = self.uri + with open(self.path, "a") as file: + prediction.write_rttm(file) def on_error(self, error: Exception): self.patch() @@ -55,30 +58,30 @@ def on_completed(self): self.patch() -class DiarizationPredictionAccumulator(Observer): +class PredictionAccumulator(Observer): def __init__(self, uri: Optional[Text] = None, patch_collar: float = 0.05): super().__init__() self.uri = uri self.patch_collar = patch_collar - self._annotation = None + self._prediction: Optional[Annotation] = None def patch(self): """Stitch same-speaker turns that are close to each other""" - if self._annotation is not None: - self._annotation = self._annotation.support(self.patch_collar) + if self._prediction is not None: + self._prediction = self._prediction.support(self.patch_collar) def get_prediction(self) -> Annotation: # Patch again in case this is called before on_completed self.patch() - return self._annotation + return self._prediction def on_next(self, value: Union[Tuple, Annotation]): - annotation = _extract_annotation(value) - annotation.uri = self.uri - if self._annotation is None: - self._annotation = annotation + prediction = _extract_prediction(value) + prediction.uri = self.uri + if self._prediction is None: + self._prediction = prediction else: - self._annotation.update(annotation) + self._prediction.update(prediction) def on_error(self, error: Exception): self.patch() @@ -87,7 +90,7 @@ def on_completed(self): self.patch() -class RealTimePlot(Observer): +class StreamingPlot(Observer): def __init__( self, duration: float, @@ -118,10 +121,12 @@ def _init_num_axs(self): def _init_figure(self): self._init_num_axs() - self.figure, self.axs = plt.subplots(self.num_axs, 1, figsize=(10, 2 * self.num_axs)) + self.figure, self.axs = plt.subplots( + self.num_axs, 1, figsize=(10, 2 * self.num_axs) + ) if self.num_axs == 1: self.axs = [self.axs] - self.figure.canvas.mpl_connect('close_event', self._on_window_closed) + self.figure.canvas.mpl_connect("close_event", self._on_window_closed) def _clear_axs(self): for i in range(self.num_axs): @@ -131,7 +136,7 @@ def get_plot_bounds(self, real_time: float) -> Segment: start_time = 0 end_time = real_time - self.latency if self.visualization == "slide": - start_time = max(0., end_time - self.window_duration) + start_time = max(0.0, end_time - self.window_duration) return Segment(start_time, end_time) def on_next(self, values: Tuple[Annotation, SlidingWindowFeature, float]): @@ -139,6 +144,7 @@ def on_next(self, values: Tuple[Annotation, SlidingWindowFeature, float]): raise WindowClosedException prediction, waveform, real_time = 
values + # Initialize figure if first call if self.figure is None: self._init_figure() @@ -147,15 +153,21 @@ def on_next(self, values: Tuple[Annotation, SlidingWindowFeature, float]): # Set plot bounds notebook.crop = self.get_plot_bounds(real_time) - # Plot current values + # Align prediction and reference if possible if self.reference is not None: metric = DiarizationErrorRate() mapping = metric.optimal_mapping(self.reference, prediction) prediction.rename_labels(mapping=mapping, copy=False) + + # Plot prediction notebook.plot_annotation(prediction, self.axs[0]) self.axs[0].set_title("Output") + + # Plot waveform notebook.plot_feature(waveform, self.axs[1]) self.axs[1].set_title("Audio") + + # Plot reference if available if self.num_axs == 3: notebook.plot_annotation(self.reference, self.axs[2]) self.axs[2].set_title("Reference") diff --git a/src/diart/sources.py b/src/diart/sources.py index 0f5dedf7..82051b2e 100644 --- a/src/diart/sources.py +++ b/src/diart/sources.py @@ -1,3 +1,4 @@ +from abc import ABC, abstractmethod from pathlib import Path from queue import SimpleQueue from typing import Text, Optional, AnyStr, Dict, Any, Union, Tuple @@ -5,16 +6,16 @@ import numpy as np import sounddevice as sd import torch -from diart import utils from einops import rearrange from rx.subject import Subject from torchaudio.io import StreamReader from websocket_server import WebsocketServer +from . import utils from .audio import FilePath, AudioLoader -class AudioSource: +class AudioSource(ABC): """Represents a source of audio that can start streaming via the `stream` property. Parameters @@ -24,6 +25,7 @@ class AudioSource: sample_rate: int Sample rate of the audio source. """ + def __init__(self, uri: Text, sample_rate: int): self.uri = uri self.sample_rate = sample_rate @@ -34,13 +36,15 @@ def duration(self) -> Optional[float]: """The duration of the stream if known. Defaults to None (unknown duration).""" return None + @abstractmethod def read(self): """Start reading the source and yielding samples through the stream.""" - raise NotImplementedError + pass + @abstractmethod def close(self): """Stop reading the source and close all open streams.""" - raise NotImplementedError + pass class FileAudioSource(AudioSource): @@ -55,23 +59,24 @@ class FileAudioSource(AudioSource): padding: (float, float) Left and right padding to add to the file (in seconds). Defaults to (0, 0). - block_size: int - Number of samples per chunk emitted. - Defaults to 1000. + block_duration: int + Duration of each emitted chunk in seconds. + Defaults to 0.5 seconds. 
""" + def __init__( self, file: FilePath, sample_rate: int, padding: Tuple[float, float] = (0, 0), - block_size: int = 1000, + block_duration: float = 0.5, ): super().__init__(Path(file).stem, sample_rate) self.loader = AudioLoader(self.sample_rate, mono=True) self._duration = self.loader.get_duration(file) self.file = file self.resolution = 1 / self.sample_rate - self.block_size = block_size + self.block_size = int(np.rint(block_duration * self.sample_rate)) self.padding_start, self.padding_end = padding self.is_closed = False @@ -105,9 +110,13 @@ def read(self): # Add last incomplete chunk with padding if num_samples % self.block_size != 0: - last_chunk = waveform[:, chunks.shape[0] * self.block_size:].unsqueeze(0).numpy() + last_chunk = ( + waveform[:, chunks.shape[0] * self.block_size :].unsqueeze(0).numpy() + ) diff_samples = self.block_size - last_chunk.shape[-1] - last_chunk = np.concatenate([last_chunk, np.zeros((1, 1, diff_samples))], axis=-1) + last_chunk = np.concatenate( + [last_chunk, np.zeros((1, 1, diff_samples))], axis=-1 + ) chunks = np.vstack([chunks, last_chunk]) # Stream blocks @@ -131,11 +140,9 @@ class MicrophoneAudioSource(AudioSource): Parameters ---------- - sample_rate: int - Sample rate for the emitted audio chunks. - block_size: int - Number of samples per chunk emitted. - Defaults to 1000. + block_duration: int + Duration of each emitted chunk in seconds. + Defaults to 0.5 seconds. device: int | str | (int, str) | None Device identifier compatible for the sounddevice stream. If None, use the default device. @@ -144,15 +151,27 @@ class MicrophoneAudioSource(AudioSource): def __init__( self, - sample_rate: int, - block_size: int = 1000, + block_duration: float = 0.5, device: Optional[Union[int, Text, Tuple[int, Text]]] = None, ): - super().__init__("live_recording", sample_rate) - self.block_size = block_size + # Use the lowest supported sample rate + sample_rates = [16000, 32000, 44100, 48000] + best_sample_rate = None + for sr in sample_rates: + try: + sd.check_input_settings(device=device, samplerate=sr) + except Exception: + pass + else: + best_sample_rate = sr + break + super().__init__(f"input_device:{device}", best_sample_rate) + + # Determine block size in samples and create input stream + self.block_size = int(np.rint(block_duration * self.sample_rate)) self._mic_stream = sd.InputStream( channels=1, - samplerate=sample_rate, + samplerate=self.sample_rate, latency=0, blocksize=self.block_size, callback=self._read_callback, @@ -202,13 +221,14 @@ class WebSocketAudioSource(AudioSource): Path to a certificate if using SSL. Defaults to no certificate. """ + def __init__( self, sample_rate: int, host: Text = "127.0.0.1", port: int = 7007, key: Optional[Union[Text, Path]] = None, - certificate: Optional[Union[Text, Path]] = None + certificate: Optional[Union[Text, Path]] = None, ): # FIXME sample_rate is not being used, this can be confusing and lead to incompatibilities. 
# I would prefer the client to send a JSON with data and sample rate, then resample if needed @@ -258,10 +278,10 @@ def __init__( sample_rate: int, streamer: StreamReader, stream_index: Optional[int] = None, - block_size: int = 1000, + block_duration: float = 0.5, ): super().__init__(uri, sample_rate) - self.block_size = block_size + self.block_size = int(np.rint(block_duration * self.sample_rate)) self._streamer = streamer self._streamer.add_basic_audio_stream( frames_per_chunk=self.block_size, @@ -287,3 +307,16 @@ def read(self): def close(self): self.is_closed = True + + +class AppleDeviceAudioSource(TorchStreamAudioSource): + def __init__( + self, + sample_rate: int, + device: str = "0:0", + stream_index: int = 0, + block_duration: float = 0.5, + ): + uri = f"apple_input_device:{device}" + streamer = StreamReader(device, format="avfoundation") + super().__init__(uri, sample_rate, streamer, stream_index, block_duration) diff --git a/src/diart/utils.py b/src/diart/utils.py index e90861c7..ca27d022 100644 --- a/src/diart/utils.py +++ b/src/diart/utils.py @@ -1,12 +1,14 @@ import base64 import time -from typing import Optional, Text, Union, Any, Dict +from typing import Optional, Text, Union import matplotlib.pyplot as plt import numpy as np -from diart.progress import ProgressBar from pyannote.core import Annotation, Segment, SlidingWindowFeature, notebook +from . import blocks +from .progress import ProgressBar + class Chronometer: def __init__(self, unit: Text, progress_bar: Optional[ProgressBar] = None): @@ -51,10 +53,6 @@ def parse_hf_token_arg(hf_token: Union[bool, Text]) -> Union[bool, Text]: return hf_token -def get(data: Dict[Text, Any], key: Text, default: Any) -> Any: - return data[key] if key in data else default - - def encode_audio(waveform: np.ndarray) -> Text: data = waveform.astype(np.float32).tobytes() return base64.b64encode(data).decode("utf-8") @@ -74,6 +72,18 @@ def get_padding_left(stream_duration: float, chunk_duration: float) -> float: return 0 +def repeat_label(label: Text): + while True: + yield label + + +def get_pipeline_class(class_name: Text) -> type: + pipeline_class = getattr(blocks, class_name, None) + msg = f"Pipeline '{class_name}' doesn't exist" + assert pipeline_class is not None, msg + return pipeline_class + + def get_padding_right(latency: float, step: float) -> float: return latency - step @@ -88,6 +98,7 @@ def apply(feature: SlidingWindowFeature): notebook.plot_feature(feature) plt.tight_layout() plt.show() + return apply @@ -102,4 +113,5 @@ def apply(annotation: Annotation): notebook.plot_annotation(annotation) plt.tight_layout() plt.show() + return apply
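With this patch, audio sources are sized by `block_duration` in seconds instead of a sample count, and `diart.stream` wires a `FileAudioSource` straight into `StreamingInference`. The snippet below is a minimal sketch of the same flow from Python, based on the constructors introduced above; the audio file path is a placeholder.

```python
from pathlib import Path

from diart import SpeakerDiarization
from diart.inference import StreamingInference
from diart.sources import FileAudioSource

pipeline = SpeakerDiarization()
config = pipeline.config

# Placeholder path to the conversation to diarize
audio_file = Path("/path/to/conversation.wav")

# Pad the file according to the pipeline latency and shift predictions back accordingly
padding = config.get_file_padding(audio_file)
source = FileAudioSource(
    audio_file,
    config.sample_rate,
    padding,
    block_duration=config.step,  # chunk duration in seconds, not samples
)
pipeline.set_timestamp_shift(-padding[0])

inference = StreamingInference(pipeline, source, batch_size=1, do_plot=False)
prediction = inference()
```

If the source sample rate does not match the pipeline's, `StreamingInference` now resamples the stream on the fly before forming chunks.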
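`Benchmark` and `Parallelize` now accept an explicit `pyannote.metrics` metric and otherwise fall back to the pipeline's `suggest_metric()`. A minimal sketch with placeholder dataset directories:

```python
from pyannote.metrics.diarization import DiarizationErrorRate

from diart import SpeakerDiarization, SpeakerDiarizationConfig
from diart.inference import Benchmark

# Placeholder dataset directories
benchmark = Benchmark("/wav/dir", "/rttm/dir")
config = SpeakerDiarizationConfig(step=0.5, latency=0.5)

# An explicit metric can be passed; omitting it uses SpeakerDiarization.suggest_metric()
report = benchmark(SpeakerDiarization, config, metric=DiarizationErrorRate())
```

Since a reference directory is given, the call returns a `pandas.DataFrame` report; without one it returns the list of predictions.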
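Finally, a usage sketch of the reworked `Optimizer` entry point: the pipeline class is now the first argument, hyper-parameters default to `pipeline_class.hyper_parameters()`, and an optional `metric`/`direction` pair controls the study objective. Paths below are placeholders and the working directory is assumed to already exist.

```python
from pathlib import Path

from pyannote.metrics.diarization import DiarizationErrorRate

from diart import SpeakerDiarization
from diart.optim import Optimizer

optimizer = Optimizer(
    pipeline_class=SpeakerDiarization,
    speech_path="/wav/dir",            # placeholder audio directory
    reference_path="/rttm/dir",        # placeholder reference directory
    study_or_path=Path("/output/tuning"),  # existing directory; study name taken from its stem
    batch_size=32,
    # hparams=None falls back to SpeakerDiarization.hyper_parameters()
    metric=DiarizationErrorRate(),     # optional, defaults to the pipeline's suggested metric
    direction="minimize",              # use "maximize" for metrics where higher is better
)
optimizer(num_iter=100, show_progress=True)
```

Each trial runs the internal `Benchmark` with the chosen metric and reads the `TOTAL` row of the resulting report as the objective value.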