diff --git a/README.md b/README.md
index ef533946..f9d87b91 100644
--- a/README.md
+++ b/README.md
@@ -24,7 +24,7 @@
|
- 🤖 Custom models
+ 🤖 Add your model
|
@@ -64,17 +64,11 @@
1) Create environment:
```shell
-conda create -n diart python=3.8
+conda env create -f diart/environment.yml
conda activate diart
```
-2) Install audio libraries:
-
-```shell
-conda install portaudio pysoundfile ffmpeg -c conda-forge
-```
-
-3) Install diart:
+2) Install the package:
```shell
pip install diart
```
@@ -110,32 +104,32 @@ See `diart.stream -h` for more options.
### From python
-Use `RealTimeInference` to easily run a pipeline on an audio source and write the results to disk:
+Use `StreamingInference` to run a pipeline on an audio source and write the results to disk:
```python
-from diart import OnlineSpeakerDiarization
+from diart import SpeakerDiarization
from diart.sources import MicrophoneAudioSource
-from diart.inference import RealTimeInference
+from diart.inference import StreamingInference
from diart.sinks import RTTMWriter
-pipeline = OnlineSpeakerDiarization()
-mic = MicrophoneAudioSource(pipeline.config.sample_rate)
-inference = RealTimeInference(pipeline, mic, do_plot=True)
+pipeline = SpeakerDiarization()
+mic = MicrophoneAudioSource()
+inference = StreamingInference(pipeline, mic, do_plot=True)
inference.attach_observers(RTTMWriter(mic.uri, "/output/file.rttm"))
prediction = inference()
```
-For inference and evaluation on a dataset we recommend to use `Benchmark` (see notes on [reproducibility](#reproducibility)).
+For inference and evaluation on a dataset, we recommend using `Benchmark` (see notes on [reproducibility](#-reproducibility)).
-## 🤖 Custom models
+## 🤖 Add your model
-Third-party models can be integrated seamlessly by subclassing `SegmentationModel` and `EmbeddingModel` (which are PyTorch `Module` subclasses):
+Third-party models can be integrated by subclassing `SegmentationModel` and `EmbeddingModel` (both are PyTorch `nn.Module` subclasses):
```python
-from diart import OnlineSpeakerDiarization, PipelineConfig
+from diart import SpeakerDiarization, SpeakerDiarizationConfig
from diart.models import EmbeddingModel, SegmentationModel
from diart.sources import MicrophoneAudioSource
-from diart.inference import RealTimeInference
+from diart.inference import StreamingInference
def model_loader():
@@ -168,19 +162,19 @@ class MyEmbeddingModel(EmbeddingModel):
return self.model(waveform, weights)
-config = PipelineConfig(
+config = SpeakerDiarizationConfig(
segmentation=MySegmentationModel(),
embedding=MyEmbeddingModel()
)
-pipeline = OnlineSpeakerDiarization(config)
-mic = MicrophoneAudioSource(config.sample_rate)
-inference = RealTimeInference(pipeline, mic)
+pipeline = SpeakerDiarization(config)
+mic = MicrophoneAudioSource()
+inference = StreamingInference(pipeline, mic)
prediction = inference()
```
## 📈 Tune hyper-parameters
-Diart implements a hyper-parameter optimizer based on [optuna](https://optuna.readthedocs.io/en/stable/index.html) that allows you to tune any pipeline to any dataset.
+Diart implements an optimizer based on [optuna](https://optuna.readthedocs.io/en/stable/index.html) that allows you to tune pipeline hyper-parameters to your own data.
### From the command line
@@ -247,12 +241,11 @@ from diart.blocks import SpeakerSegmentation, OverlapAwareSpeakerEmbedding
segmentation = SpeakerSegmentation.from_pyannote("pyannote/segmentation")
embedding = OverlapAwareSpeakerEmbedding.from_pyannote("pyannote/embedding")
-sample_rate = segmentation.model.sample_rate
-mic = MicrophoneAudioSource(sample_rate)
+mic = MicrophoneAudioSource()
stream = mic.stream.pipe(
# Reformat stream to 5s duration and 500ms shift
- dops.rearrange_audio_stream(sample_rate=sample_rate),
+ dops.rearrange_audio_stream(sample_rate=segmentation.model.sample_rate),
ops.map(lambda wav: (wav, segmentation(wav))),
ops.starmap(embedding)
).subscribe(on_next=lambda emb: print(emb.shape))
@@ -281,7 +274,7 @@ diart.serve --host 0.0.0.0 --port 7007
diart.client microphone --host --port 7007
```
-**Note:** please make sure that the client uses the same `step` and `sample_rate` than the server with `--step` and `-sr`.
+**Note:** make sure that the client uses the same `step` and `sample_rate` as the server, setting them with `--step` and `-sr`.
See `-h` for more options.
@@ -290,13 +283,13 @@ See `-h` for more options.
For customized solutions, a server can also be created in python using the `WebSocketAudioSource`:
```python
-from diart import OnlineSpeakerDiarization
+from diart import SpeakerDiarization
from diart.sources import WebSocketAudioSource
-from diart.inference import RealTimeInference
+from diart.inference import StreamingInference
-pipeline = OnlineSpeakerDiarization()
+pipeline = SpeakerDiarization()
source = WebSocketAudioSource(pipeline.config.sample_rate, "localhost", 7007)
-inference = RealTimeInference(pipeline, source)
+inference = StreamingInference(pipeline, source)
inference.attach_hooks(lambda ann_wav: source.send(ann_wav[0].to_rttm()))
prediction = inference()
```
@@ -347,21 +340,21 @@ To obtain the best results, make sure to use the following hyper-parameters:
`diart.benchmark` and `diart.inference.Benchmark` can run, evaluate and measure the real-time latency of the pipeline. For instance, for a DIHARD III configuration:
```shell
-diart.benchmark /wav/dir --reference /rttm/dir --tau=0.555 --rho=0.422 --delta=1.517 --segmentation pyannote/segmentation@Interspeech2021
+diart.benchmark /wav/dir --reference /rttm/dir --tau-active=0.555 --rho-update=0.422 --delta-new=1.517 --segmentation pyannote/segmentation@Interspeech2021
```
or using the inference API:
```python
from diart.inference import Benchmark, Parallelize
-from diart import OnlineSpeakerDiarization, PipelineConfig
+from diart import SpeakerDiarization, SpeakerDiarizationConfig
from diart.models import SegmentationModel
benchmark = Benchmark("/wav/dir", "/rttm/dir")
name = "pyannote/segmentation@Interspeech2021"
segmentation = SegmentationModel.from_pyannote(name)
-config = PipelineConfig(
+config = SpeakerDiarizationConfig(
# Set the model used in the paper
segmentation=segmentation,
step=0.5,
@@ -370,12 +363,12 @@ config = PipelineConfig(
rho_update=0.422,
delta_new=1.517
)
-benchmark(OnlineSpeakerDiarization, config)
+benchmark(SpeakerDiarization, config)
# Run the same benchmark in parallel
p_benchmark = Parallelize(benchmark, num_workers=4)
if __name__ == "__main__": # Needed for multiprocessing
- p_benchmark(OnlineSpeakerDiarization, config)
+ p_benchmark(SpeakerDiarization, config)
```
This pre-calculates model outputs in batches, so it runs a lot faster.
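As a hedged aside to the README changes above (not part of this diff; the paths are placeholders), the same `StreamingInference` API also works with a pre-recorded file through `FileAudioSource`, which the console client below uses as well:

```python
# Hedged sketch: same streaming API as the README example, but reading from a file
# instead of the microphone. "/path/to/audio.wav" and "/output/file.rttm" are placeholders.
from diart import SpeakerDiarization
from diart.sources import FileAudioSource
from diart.inference import StreamingInference
from diart.sinks import RTTMWriter

pipeline = SpeakerDiarization()
source = FileAudioSource("/path/to/audio.wav", pipeline.config.sample_rate)
inference = StreamingInference(pipeline, source, do_plot=False)
inference.attach_observers(RTTMWriter(source.uri, "/output/file.rttm"))
prediction = inference()
```

For offline evaluation, the `Benchmark` API shown above remains the faster option, since it pre-calculates model outputs in batches.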
diff --git a/environment.yml b/environment.yml
new file mode 100644
index 00000000..f62b4274
--- /dev/null
+++ b/environment.yml
@@ -0,0 +1,12 @@
+name: diart
+channels:
+ - conda-forge
+ - defaults
+dependencies:
+ - python=3.8
+ - portaudio=19.6.*
+ - pysoundfile=0.12.*
+ - ffmpeg[version='<4.4']
+ - pip
+ - pip:
+ - .
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 50662023..e0d93213 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,7 +8,7 @@ tqdm>=4.64.0
pandas>=1.4.2
torch>=1.12.1
torchvision>=0.14.0
-torchaudio>=0.12.1,<1.0
+torchaudio>=2.0.2
pyannote.audio>=2.1.1
pyannote.core>=4.5
pyannote.database>=4.1.1
diff --git a/setup.cfg b/setup.cfg
index 594c876e..f38a612e 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,12 +1,12 @@
[metadata]
name=diart
-version=0.7.0
+version=0.8.0
author=Juan Manuel Coria
-description=Speaker diarization in real time
+description=Streaming speaker diarization in real time
long_description=file: README.md
long_description_content_type=text/markdown
keywords=speaker diarization, streaming, online, real time, rxpy
-url=https://github.com/juanmc2005/StreamingSpeakerDiarization
+url=https://github.com/juanmc2005/diart
license=MIT
classifiers=
Development Status :: 4 - Beta
@@ -30,7 +30,7 @@ install_requires=
pandas>=1.4.2
torch>=1.12.1
torchvision>=0.14.0
- torchaudio>=0.12.1,<1.0
+ torchaudio>=2.0.2
pyannote.audio>=2.1.1
pyannote.core>=4.5
pyannote.database>=4.1.1
diff --git a/src/diart/__init__.py b/src/diart/__init__.py
index c9692638..4bd51327 100644
--- a/src/diart/__init__.py
+++ b/src/diart/__init__.py
@@ -1,6 +1,8 @@
from .blocks import (
- OnlineSpeakerDiarization,
- BasePipeline,
+ SpeakerDiarization,
+ Pipeline,
+ SpeakerDiarizationConfig,
PipelineConfig,
- BasePipelineConfig,
+ VoiceActivityDetection,
+ VoiceActivityDetectionConfig,
)
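For readers migrating from 0.7, a minimal sketch of how the renamed exports above are imported after this change (old names kept as comments for reference):

```python
# diart < 0.8 (removed by this PR):
# from diart import OnlineSpeakerDiarization, BasePipeline, PipelineConfig, BasePipelineConfig

# diart >= 0.8, as exported above. Note that PipelineConfig is now the abstract base
# (formerly BasePipelineConfig), while the old concrete PipelineConfig becomes SpeakerDiarizationConfig.
from diart import (
    SpeakerDiarization,
    SpeakerDiarizationConfig,
    Pipeline,
    PipelineConfig,
    VoiceActivityDetection,
    VoiceActivityDetectionConfig,
)
```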
diff --git a/src/diart/argdoc.py b/src/diart/argdoc.py
index d16df41e..e89caa28 100644
--- a/src/diart/argdoc.py
+++ b/src/diart/argdoc.py
@@ -1,5 +1,6 @@
SEGMENTATION = "Segmentation model name from pyannote"
EMBEDDING = "Embedding model name from pyannote"
+DURATION = "Chunk duration (in seconds)"
STEP = "Sliding window step (in seconds)"
LATENCY = "System latency (in seconds). STEP <= LATENCY <= CHUNK_DURATION"
TAU = "Probability threshold to consider a speaker as active. 0 <= TAU <= 1"
diff --git a/src/diart/blocks/__init__.py b/src/diart/blocks/__init__.py
index 59a6ef36..15cf81d9 100644
--- a/src/diart/blocks/__init__.py
+++ b/src/diart/blocks/__init__.py
@@ -13,6 +13,7 @@
OverlapAwareSpeakerEmbedding,
)
from .segmentation import SpeakerSegmentation
-from .diarization import OnlineSpeakerDiarization, BasePipeline
-from .config import BasePipelineConfig, PipelineConfig
+from .diarization import SpeakerDiarization, SpeakerDiarizationConfig
+from .base import PipelineConfig, Pipeline
from .utils import Binarize, Resample, AdjustVolume
+from .vad import VoiceActivityDetection, VoiceActivityDetectionConfig
diff --git a/src/diart/blocks/aggregation.py b/src/diart/blocks/aggregation.py
index b6352a28..aa5e6a1e 100644
--- a/src/diart/blocks/aggregation.py
+++ b/src/diart/blocks/aggregation.py
@@ -1,3 +1,4 @@
+from abc import ABC, abstractmethod
from typing import Optional, List
import numpy as np
@@ -5,7 +6,7 @@
from typing_extensions import Literal
-class AggregationStrategy:
+class AggregationStrategy(ABC):
"""Abstract class representing a strategy to aggregate overlapping buffers
Parameters
@@ -17,14 +18,18 @@ class AggregationStrategy:
"""
def __init__(self, cropping_mode: Literal["strict", "loose", "center"] = "loose"):
- assert cropping_mode in ["strict", "loose", "center"], f"Invalid cropping mode `{cropping_mode}`"
+ assert cropping_mode in [
+ "strict",
+ "loose",
+ "center",
+ ], f"Invalid cropping mode `{cropping_mode}`"
self.cropping_mode = cropping_mode
@staticmethod
def build(
name: Literal["mean", "hamming", "first"],
- cropping_mode: Literal["strict", "loose", "center"] = "loose"
- ) -> 'AggregationStrategy':
+ cropping_mode: Literal["strict", "loose", "center"] = "loose",
+ ) -> "AggregationStrategy":
"""Build an AggregationStrategy instance based on its name"""
assert name in ("mean", "hamming", "first")
if name == "mean":
@@ -34,7 +39,9 @@ def build(
else:
return FirstOnlyStrategy(cropping_mode)
- def __call__(self, buffers: List[SlidingWindowFeature], focus: Segment) -> SlidingWindowFeature:
+ def __call__(
+ self, buffers: List[SlidingWindowFeature], focus: Segment
+ ) -> SlidingWindowFeature:
"""Aggregate chunks over a specific region.
Parameters
@@ -52,20 +59,23 @@ def __call__(self, buffers: List[SlidingWindowFeature], focus: Segment) -> Slidi
aggregation = self.aggregate(buffers, focus)
resolution = focus.duration / aggregation.shape[0]
resolution = SlidingWindow(
- start=focus.start,
- duration=resolution,
- step=resolution
+ start=focus.start, duration=resolution, step=resolution
)
return SlidingWindowFeature(aggregation, resolution)
- def aggregate(self, buffers: List[SlidingWindowFeature], focus: Segment) -> np.ndarray:
- raise NotImplementedError
+ @abstractmethod
+ def aggregate(
+ self, buffers: List[SlidingWindowFeature], focus: Segment
+ ) -> np.ndarray:
+ pass
class HammingWeightedAverageStrategy(AggregationStrategy):
"""Compute the average weighted by the corresponding Hamming-window aligned to each buffer"""
- def aggregate(self, buffers: List[SlidingWindowFeature], focus: Segment) -> np.ndarray:
+ def aggregate(
+ self, buffers: List[SlidingWindowFeature], focus: Segment
+ ) -> np.ndarray:
num_frames, num_speakers = buffers[0].data.shape
hamming, intersection = [], []
for buffer in buffers:
@@ -85,19 +95,25 @@ def aggregate(self, buffers: List[SlidingWindowFeature], focus: Segment) -> np.n
class AverageStrategy(AggregationStrategy):
"""Compute a simple average over the focus region"""
- def aggregate(self, buffers: List[SlidingWindowFeature], focus: Segment) -> np.ndarray:
+ def aggregate(
+ self, buffers: List[SlidingWindowFeature], focus: Segment
+ ) -> np.ndarray:
# Stack all overlapping regions
- intersection = np.stack([
- buffer.crop(focus, mode=self.cropping_mode, fixed=focus.duration)
- for buffer in buffers
- ])
+ intersection = np.stack(
+ [
+ buffer.crop(focus, mode=self.cropping_mode, fixed=focus.duration)
+ for buffer in buffers
+ ]
+ )
return np.mean(intersection, axis=0)
class FirstOnlyStrategy(AggregationStrategy):
"""Instead of aggregating, keep the first focus region in the buffer list"""
- def aggregate(self, buffers: List[SlidingWindowFeature], focus: Segment) -> np.ndarray:
+ def aggregate(
+ self, buffers: List[SlidingWindowFeature], focus: Segment
+ ) -> np.ndarray:
return buffers[0].crop(focus, mode=self.cropping_mode, fixed=focus.duration)
@@ -149,12 +165,16 @@ def __init__(
step: float,
latency: Optional[float] = None,
strategy: Literal["mean", "hamming", "first"] = "hamming",
- cropping_mode: Literal["strict", "loose", "center"] = "loose"
+ cropping_mode: Literal["strict", "loose", "center"] = "loose",
):
self.step = step
self.latency = latency
self.strategy = strategy
- assert cropping_mode in ["strict", "loose", "center"], f"Invalid cropping mode `{cropping_mode}`"
+ assert cropping_mode in [
+ "strict",
+ "loose",
+ "center",
+ ], f"Invalid cropping mode `{cropping_mode}`"
self.cropping_mode = cropping_mode
if self.latency is None:
@@ -169,7 +189,7 @@ def _prepend(
self,
output_window: SlidingWindowFeature,
output_region: Segment,
- buffers: List[SlidingWindowFeature]
+ buffers: List[SlidingWindowFeature],
):
# FIXME instead of prepending the output of the first chunk,
# add padding of `chunk_duration - latency` seconds at the
@@ -187,7 +207,7 @@ def _prepend(
resolution = output_region.end / first_output.shape[0]
output_window = SlidingWindowFeature(
first_output,
- SlidingWindow(start=0, duration=resolution, step=resolution)
+ SlidingWindow(start=0, duration=resolution, step=resolution),
)
return output_window
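Since `AggregationStrategy` is now an abstract base class, custom strategies must override `aggregate`. A hypothetical sketch following the same pattern as `AverageStrategy` above (this `MedianStrategy` is not part of diart):

```python
from typing import List

import numpy as np
from pyannote.core import Segment, SlidingWindowFeature

from diart.blocks.aggregation import AggregationStrategy


class MedianStrategy(AggregationStrategy):
    """Aggregate overlapping buffers with an element-wise median over the focus region."""

    def aggregate(
        self, buffers: List[SlidingWindowFeature], focus: Segment
    ) -> np.ndarray:
        # Stack all overlapping regions, then take the median instead of the mean
        stacked = np.stack(
            [
                buffer.crop(focus, mode=self.cropping_mode, fixed=focus.duration)
                for buffer in buffers
            ]
        )
        return np.median(stacked, axis=0)


# Usage mirrors the built-in strategies, e.g. MedianStrategy(cropping_mode="loose")
```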
diff --git a/src/diart/blocks/base.py b/src/diart/blocks/base.py
new file mode 100644
index 00000000..f6ca3a33
--- /dev/null
+++ b/src/diart/blocks/base.py
@@ -0,0 +1,95 @@
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Any, Tuple, Sequence, Text
+
+from pyannote.core import SlidingWindowFeature
+from pyannote.metrics.base import BaseMetric
+
+from .. import utils
+from ..audio import FilePath, AudioLoader
+
+
+@dataclass
+class HyperParameter:
+ name: Text
+ low: float
+ high: float
+
+ @staticmethod
+ def from_name(name: Text) -> "HyperParameter":
+ if name == "tau_active":
+ return TauActive
+ if name == "rho_update":
+ return RhoUpdate
+ if name == "delta_new":
+ return DeltaNew
+ raise ValueError(f"Hyper-parameter '{name}' not recognized")
+
+
+TauActive = HyperParameter("tau_active", low=0, high=1)
+RhoUpdate = HyperParameter("rho_update", low=0, high=1)
+DeltaNew = HyperParameter("delta_new", low=0, high=2)
+
+
+class PipelineConfig(ABC):
+ @property
+ @abstractmethod
+ def duration(self) -> float:
+ pass
+
+ @property
+ @abstractmethod
+ def step(self) -> float:
+ pass
+
+ @property
+ @abstractmethod
+ def latency(self) -> float:
+ pass
+
+ @property
+ @abstractmethod
+ def sample_rate(self) -> int:
+ pass
+
+ def get_file_padding(self, filepath: FilePath) -> Tuple[float, float]:
+ file_duration = AudioLoader(self.sample_rate, mono=True).get_duration(filepath)
+ right = utils.get_padding_right(self.latency, self.step)
+ left = utils.get_padding_left(file_duration + right, self.duration)
+ return left, right
+
+
+class Pipeline(ABC):
+ @staticmethod
+ @abstractmethod
+ def get_config_class() -> type:
+ pass
+
+ @staticmethod
+ @abstractmethod
+ def suggest_metric() -> BaseMetric:
+ pass
+
+ @staticmethod
+ @abstractmethod
+ def hyper_parameters() -> Sequence[HyperParameter]:
+ pass
+
+ @property
+ @abstractmethod
+ def config(self) -> PipelineConfig:
+ pass
+
+ @abstractmethod
+ def reset(self):
+ pass
+
+ @abstractmethod
+ def set_timestamp_shift(self, shift: float):
+ pass
+
+ @abstractmethod
+ def __call__(
+ self, waveforms: Sequence[SlidingWindowFeature]
+ ) -> Sequence[Tuple[Any, SlidingWindowFeature]]:
+ pass
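A quick sketch of how the new `HyperParameter` helpers defined above resolve a name to its search range:

```python
from diart.blocks.base import HyperParameter, TauActive

param = HyperParameter.from_name("tau_active")
assert param is TauActive
print(param.name, param.low, param.high)  # tau_active 0 1
```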
diff --git a/src/diart/blocks/clustering.py b/src/diart/blocks/clustering.py
index 882001b9..b7217c0a 100644
--- a/src/diart/blocks/clustering.py
+++ b/src/diart/blocks/clustering.py
@@ -27,13 +27,14 @@ class OnlineSpeakerClustering:
max_speakers: int
Maximum number of global speakers to track through a conversation. Defaults to 20.
"""
+
def __init__(
self,
tau_active: float,
rho_update: float,
delta_new: float,
metric: Optional[str] = "cosine",
- max_speakers: int = 20
+ max_speakers: int = 20,
):
self.tau_active = tau_active
self.rho_update = rho_update
@@ -116,9 +117,7 @@ def add_center(self, embedding: np.ndarray) -> int:
return center
def identify(
- self,
- segmentation: SlidingWindowFeature,
- embeddings: torch.Tensor
+ self, segmentation: SlidingWindowFeature, embeddings: torch.Tensor
) -> SpeakerMap:
"""Identify the centroids to which the input speaker embeddings belong.
@@ -135,15 +134,18 @@ def identify(
A mapping from local speakers to global speakers.
"""
embeddings = embeddings.detach().cpu().numpy()
- active_speakers = np.where(np.max(segmentation.data, axis=0) >= self.tau_active)[0]
- long_speakers = np.where(np.mean(segmentation.data, axis=0) >= self.rho_update)[0]
+ active_speakers = np.where(
+ np.max(segmentation.data, axis=0) >= self.tau_active
+ )[0]
+ long_speakers = np.where(np.mean(segmentation.data, axis=0) >= self.rho_update)[
+ 0
+ ]
num_local_speakers = segmentation.data.shape[1]
if self.centers is None:
self.init_centers(embeddings.shape[1])
assignments = [
- (spk, self.add_center(embeddings[spk]))
- for spk in active_speakers
+ (spk, self.add_center(embeddings[spk])) for spk in active_speakers
]
return SpeakerMapBuilder.hard_map(
shape=(num_local_speakers, self.max_speakers),
@@ -154,18 +156,16 @@ def identify(
# Obtain a mapping based on distances between embeddings and centers
dist_map = SpeakerMapBuilder.dist(embeddings, self.centers, self.metric)
# Remove any assignments containing invalid speakers
- inactive_speakers = np.array([
- spk for spk in range(num_local_speakers)
- if spk not in active_speakers
- ])
+ inactive_speakers = np.array(
+ [spk for spk in range(num_local_speakers) if spk not in active_speakers]
+ )
dist_map = dist_map.unmap_speakers(inactive_speakers, self.inactive_centers)
# Keep assignments under the distance threshold
valid_map = dist_map.unmap_threshold(self.delta_new)
# Some speakers might be unidentified
missed_speakers = [
- s for s in active_speakers
- if not valid_map.is_source_speaker_mapped(s)
+ s for s in active_speakers if not valid_map.is_source_speaker_mapped(s)
]
# Add assignments to new centers if possible
@@ -205,8 +205,10 @@ def identify(
return valid_map
- def __call__(self, segmentation: SlidingWindowFeature, embeddings: torch.Tensor) -> SlidingWindowFeature:
+ def __call__(
+ self, segmentation: SlidingWindowFeature, embeddings: torch.Tensor
+ ) -> SlidingWindowFeature:
return SlidingWindowFeature(
self.identify(segmentation, embeddings).apply(segmentation.data),
- segmentation.sliding_window
+ segmentation.sliding_window,
)
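For illustration, a hedged sketch of driving `OnlineSpeakerClustering` directly with dummy inputs (shapes and values are made up; inside the pipeline this block receives real segmentation scores and speaker embeddings):

```python
import numpy as np
import torch
from pyannote.core import SlidingWindow, SlidingWindowFeature

from diart.blocks.clustering import OnlineSpeakerClustering

clustering = OnlineSpeakerClustering(tau_active=0.6, rho_update=0.3, delta_new=1.0)

# Dummy local segmentation scores (frames x local speakers) and one embedding per local speaker
num_frames, num_speakers, emb_dim = 50, 3, 512
scores = np.random.rand(num_frames, num_speakers).astype(np.float32)
resolution = SlidingWindow(start=0.0, duration=0.1, step=0.1)
segmentation = SlidingWindowFeature(scores, resolution)
embeddings = torch.randn(num_speakers, emb_dim)

# Map local speakers to global centroids; the output has max_speakers columns (20 by default)
global_scores = clustering(segmentation, embeddings)
print(global_scores.data.shape)  # (50, 20)
```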
diff --git a/src/diart/blocks/config.py b/src/diart/blocks/config.py
deleted file mode 100644
index d8e2a656..00000000
--- a/src/diart/blocks/config.py
+++ /dev/null
@@ -1,153 +0,0 @@
-from typing import Any, Optional, Union, Tuple
-
-import numpy as np
-import torch
-from typing_extensions import Literal
-
-from .. import models as m
-from .. import utils
-from ..audio import FilePath, AudioLoader
-
-
-class BasePipelineConfig:
- @property
- def duration(self) -> float:
- raise NotImplementedError
-
- @property
- def step(self) -> float:
- raise NotImplementedError
-
- @property
- def latency(self) -> float:
- raise NotImplementedError
-
- @property
- def sample_rate(self) -> int:
- raise NotImplementedError
-
- @staticmethod
- def from_dict(data: Any) -> 'BasePipelineConfig':
- raise NotImplementedError
-
- def get_file_padding(self, filepath: FilePath) -> Tuple[float, float]:
- file_duration = AudioLoader(self.sample_rate, mono=True).get_duration(filepath)
- right = utils.get_padding_right(self.latency, self.step)
- left = utils.get_padding_left(file_duration + right, self.duration)
- return left, right
-
- def optimal_block_size(self) -> int:
- return int(np.rint(self.step * self.sample_rate))
-
-
-class PipelineConfig(BasePipelineConfig):
- def __init__(
- self,
- segmentation: Optional[m.SegmentationModel] = None,
- embedding: Optional[m.EmbeddingModel] = None,
- duration: Optional[float] = None,
- step: float = 0.5,
- latency: Optional[Union[float, Literal["max", "min"]]] = None,
- tau_active: float = 0.6,
- rho_update: float = 0.3,
- delta_new: float = 1,
- gamma: float = 3,
- beta: float = 10,
- max_speakers: int = 20,
- device: Optional[torch.device] = None,
- **kwargs,
- ):
- # Default segmentation model is pyannote/segmentation
- self.segmentation = segmentation
- if self.segmentation is None:
- self.segmentation = m.SegmentationModel.from_pyannote("pyannote/segmentation")
-
- # Default duration is the one given by the segmentation model
- self._duration = duration
-
- # Expected sample rate is given by the segmentation model
- self._sample_rate: Optional[int] = None
-
- # Default embedding model is pyannote/embedding
- self.embedding = embedding
- if self.embedding is None:
- self.embedding = m.EmbeddingModel.from_pyannote("pyannote/embedding")
-
- # Latency defaults to the step duration
- self._step = step
- self._latency = latency
- if self._latency is None or self._latency == "min":
- self._latency = self._step
- elif self._latency == "max":
- self._latency = self._duration
-
- self.tau_active = tau_active
- self.rho_update = rho_update
- self.delta_new = delta_new
- self.gamma = gamma
- self.beta = beta
- self.max_speakers = max_speakers
-
- self.device = device
- if self.device is None:
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- @staticmethod
- def from_dict(data: Any) -> 'PipelineConfig':
- # Check for explicit device, otherwise check for 'cpu' bool, otherwise pass None
- device = utils.get(data, "device", None)
- if device is None:
- device = torch.device("cpu") if utils.get(data, "cpu", False) else None
-
- # Instantiate models
- hf_token = utils.parse_hf_token_arg(utils.get(data, "hf_token", True))
- segmentation = utils.get(data, "segmentation", "pyannote/segmentation")
- segmentation = m.SegmentationModel.from_pyannote(segmentation, hf_token)
- embedding = utils.get(data, "embedding", "pyannote/embedding")
- embedding = m.EmbeddingModel.from_pyannote(embedding, hf_token)
-
- # Hyper-parameters and their aliases
- tau = utils.get(data, "tau_active", None)
- if tau is None:
- tau = utils.get(data, "tau", 0.6)
- rho = utils.get(data, "rho_update", None)
- if rho is None:
- rho = utils.get(data, "rho", 0.3)
- delta = utils.get(data, "delta_new", None)
- if delta is None:
- delta = utils.get(data, "delta", 1)
-
- return PipelineConfig(
- segmentation=segmentation,
- embedding=embedding,
- duration=utils.get(data, "duration", None),
- step=utils.get(data, "step", 0.5),
- latency=utils.get(data, "latency", None),
- tau_active=tau,
- rho_update=rho,
- delta_new=delta,
- gamma=utils.get(data, "gamma", 3),
- beta=utils.get(data, "beta", 10),
- max_speakers=utils.get(data, "max_speakers", 20),
- device=device,
- )
-
- @property
- def duration(self) -> float:
- if self._duration is None:
- self._duration = self.segmentation.duration
- return self._duration
-
- @property
- def step(self) -> float:
- return self._step
-
- @property
- def latency(self) -> float:
- return self._latency
-
- @property
- def sample_rate(self) -> int:
- if self._sample_rate is None:
- self._sample_rate = self.segmentation.sample_rate
- return self._sample_rate
diff --git a/src/diart/blocks/diarization.py b/src/diart/blocks/diarization.py
index 7f0e162c..fab83c36 100644
--- a/src/diart/blocks/diarization.py
+++ b/src/diart/blocks/diarization.py
@@ -1,49 +1,111 @@
-from typing import Optional, Tuple, Sequence
+from __future__ import annotations
+
+from typing import Sequence
import numpy as np
import torch
from pyannote.core import Annotation, SlidingWindowFeature, SlidingWindow, Segment
+from pyannote.metrics.base import BaseMetric
+from pyannote.metrics.diarization import DiarizationErrorRate
+from typing_extensions import Literal
+from . import base
from .aggregation import DelayedAggregation
from .clustering import OnlineSpeakerClustering
from .embedding import OverlapAwareSpeakerEmbedding
from .segmentation import SpeakerSegmentation
from .utils import Binarize
-from .config import BasePipelineConfig, PipelineConfig
+from .. import models as m
-class BasePipeline:
- @staticmethod
- def get_config_class() -> type:
- raise NotImplementedError
+class SpeakerDiarizationConfig(base.PipelineConfig):
+ def __init__(
+ self,
+ segmentation: m.SegmentationModel | None = None,
+ embedding: m.EmbeddingModel | None = None,
+ duration: float | None = None,
+ step: float = 0.5,
+ latency: float | Literal["max", "min"] | None = None,
+ tau_active: float = 0.6,
+ rho_update: float = 0.3,
+ delta_new: float = 1,
+ gamma: float = 3,
+ beta: float = 10,
+ max_speakers: int = 20,
+ device: torch.device | None = None,
+ **kwargs,
+ ):
+ # Default segmentation model is pyannote/segmentation
+ self.segmentation = segmentation or m.SegmentationModel.from_pyannote(
+ "pyannote/segmentation"
+ )
+
+ # Default embedding model is pyannote/embedding
+ self.embedding = embedding or m.EmbeddingModel.from_pyannote(
+ "pyannote/embedding"
+ )
+
+ self._duration = duration
+ self._sample_rate: int | None = None
+
+ # Latency defaults to the step duration
+ self._step = step
+ self._latency = latency
+ if self._latency is None or self._latency == "min":
+ self._latency = self._step
+ elif self._latency == "max":
+ self._latency = self._duration
+
+ self.tau_active = tau_active
+ self.rho_update = rho_update
+ self.delta_new = delta_new
+ self.gamma = gamma
+ self.beta = beta
+ self.max_speakers = max_speakers
+
+ self.device = device or torch.device(
+ "cuda" if torch.cuda.is_available() else "cpu"
+ )
@property
- def config(self) -> BasePipelineConfig:
- raise NotImplementedError
+ def duration(self) -> float:
+ # Default duration is the one given by the segmentation model
+ if self._duration is None:
+ self._duration = self.segmentation.duration
+ return self._duration
- def reset(self):
- raise NotImplementedError
+ @property
+ def step(self) -> float:
+ return self._step
- def set_timestamp_shift(self, shift: float):
- raise NotImplementedError
+ @property
+ def latency(self) -> float:
+ return self._latency
- def __call__(
- self,
- waveforms: Sequence[SlidingWindowFeature]
- ) -> Sequence[Tuple[Annotation, SlidingWindowFeature]]:
- raise NotImplementedError
+ @property
+ def sample_rate(self) -> int:
+ # Expected sample rate is given by the segmentation model
+ if self._sample_rate is None:
+ self._sample_rate = self.segmentation.sample_rate
+ return self._sample_rate
-class OnlineSpeakerDiarization(BasePipeline):
- def __init__(self, config: Optional[PipelineConfig] = None):
- self._config = PipelineConfig() if config is None else config
+class SpeakerDiarization(base.Pipeline):
+ def __init__(self, config: SpeakerDiarizationConfig | None = None):
+ self._config = SpeakerDiarizationConfig() if config is None else config
msg = f"Latency should be in the range [{self._config.step}, {self._config.duration}]"
assert self._config.step <= self._config.latency <= self._config.duration, msg
- self.segmentation = SpeakerSegmentation(self._config.segmentation, self._config.device)
+ self.segmentation = SpeakerSegmentation(
+ self._config.segmentation, self._config.device
+ )
self.embedding = OverlapAwareSpeakerEmbedding(
- self._config.embedding, self._config.gamma, self._config.beta, norm=1, device=self._config.device
+ self._config.embedding,
+ self._config.gamma,
+ self._config.beta,
+ norm=1,
+ device=self._config.device,
)
self.pred_aggregation = DelayedAggregation(
self._config.step,
@@ -67,10 +129,18 @@ def __init__(self, config: Optional[PipelineConfig] = None):
@staticmethod
def get_config_class() -> type:
- return PipelineConfig
+ return SpeakerDiarizationConfig
+
+ @staticmethod
+ def suggest_metric() -> BaseMetric:
+ return DiarizationErrorRate(collar=0, skip_overlap=False)
+
+ @staticmethod
+ def hyper_parameters() -> Sequence[base.HyperParameter]:
+ return [base.TauActive, base.RhoUpdate, base.DeltaNew]
@property
- def config(self) -> PipelineConfig:
+ def config(self) -> SpeakerDiarizationConfig:
return self._config
def set_timestamp_shift(self, shift: float):
@@ -88,9 +158,8 @@ def reset(self):
self.chunk_buffer, self.pred_buffer = [], []
def __call__(
- self,
- waveforms: Sequence[SlidingWindowFeature]
- ) -> Sequence[Tuple[Annotation, SlidingWindowFeature]]:
+ self, waveforms: Sequence[SlidingWindowFeature]
+ ) -> Sequence[tuple[Annotation, SlidingWindowFeature]]:
batch_size = len(waveforms)
msg = "Pipeline expected at least 1 input"
assert batch_size >= 1, msg
@@ -98,13 +167,17 @@ def __call__(
# Create batch from chunk sequence, shape (batch, samples, channels)
batch = torch.stack([torch.from_numpy(w.data) for w in waveforms])
- expected_num_samples = int(np.rint(self.config.duration * self.config.sample_rate))
+ expected_num_samples = int(
+ np.rint(self.config.duration * self.config.sample_rate)
+ )
msg = f"Expected {expected_num_samples} samples per chunk, but got {batch.shape[1]}"
assert batch.shape[1] == expected_num_samples, msg
# Extract segmentation and embeddings
segmentations = self.segmentation(batch) # shape (batch, frames, speakers)
- embeddings = self.embedding(batch, segmentations) # shape (batch, speakers, emb_dim)
+ embeddings = self.embedding(
+ batch, segmentations
+ ) # shape (batch, speakers, emb_dim)
seg_resolution = waveforms[0].extent.duration / segmentations.shape[1]
@@ -133,7 +206,9 @@ def __call__(
# Shift prediction timestamps if required
if self.timestamp_shift != 0:
shifted_agg_prediction = Annotation(agg_prediction.uri)
- for segment, track, speaker in agg_prediction.itertracks(yield_label=True):
+ for segment, track, speaker in agg_prediction.itertracks(
+ yield_label=True
+ ):
new_segment = Segment(
segment.start + self.timestamp_shift,
segment.end + self.timestamp_shift,
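The new static helpers on `SpeakerDiarization` can also be used on their own, e.g. to pick a metric and a search space for tuning; a short sketch:

```python
from diart import SpeakerDiarization

metric = SpeakerDiarization.suggest_metric()    # DiarizationErrorRate(collar=0, skip_overlap=False)
params = SpeakerDiarization.hyper_parameters()  # [TauActive, RhoUpdate, DeltaNew]
print([(p.name, p.low, p.high) for p in params])
# [('tau_active', 0, 1), ('rho_update', 0, 1), ('delta_new', 0, 2)]
```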
diff --git a/src/diart/blocks/embedding.py b/src/diart/blocks/embedding.py
index 7aa31c05..5cd7c39e 100644
--- a/src/diart/blocks/embedding.py
+++ b/src/diart/blocks/embedding.py
@@ -22,12 +22,14 @@ def __init__(self, model: EmbeddingModel, device: Optional[torch.device] = None)
def from_pyannote(
model,
use_hf_token: Union[Text, bool, None] = True,
- device: Optional[torch.device] = None
- ) -> 'SpeakerEmbedding':
+ device: Optional[torch.device] = None,
+ ) -> "SpeakerEmbedding":
emb_model = EmbeddingModel.from_pyannote(model, use_hf_token)
return SpeakerEmbedding(emb_model, device)
- def __call__(self, waveform: TemporalFeatures, weights: Optional[TemporalFeatures] = None) -> torch.Tensor:
+ def __call__(
+ self, waveform: TemporalFeatures, weights: Optional[TemporalFeatures] = None
+ ) -> torch.Tensor:
"""
Calculate speaker embeddings of input audio.
If weights are given, calculate many speaker embeddings from the same waveform.
@@ -58,7 +60,7 @@ def __call__(self, waveform: TemporalFeatures, weights: Optional[TemporalFeature
self.model(inputs, weights),
"(batch spk) feat -> batch spk feat",
batch=batch_size,
- spk=num_speakers
+ spk=num_speakers,
)
else:
output = self.model(inputs)
@@ -76,6 +78,7 @@ class OverlappedSpeechPenalty:
Temperature parameter (actually 1/beta) to lower joint speaker activations.
Defaults to 10.
"""
+
def __init__(self, gamma: float = 3, beta: float = 10):
self.gamma = gamma
self.beta = beta
@@ -106,7 +109,11 @@ def __call__(self, embeddings: torch.Tensor) -> torch.Tensor:
batch_size2, num_speakers2, _ = embeddings.shape
assert batch_size1 == batch_size2 and num_speakers1 == num_speakers2
with torch.no_grad():
- norm_embs = self.norm * embeddings / torch.norm(embeddings, p=2, dim=-1, keepdim=True)
+ norm_embs = (
+ self.norm
+ * embeddings
+ / torch.norm(embeddings, p=2, dim=-1, keepdim=True)
+ )
return norm_embs
@@ -131,6 +138,7 @@ class OverlapAwareSpeakerEmbedding:
The device on which to run the embedding model.
Defaults to GPU if available or CPU if not.
"""
+
def __init__(
self,
model: EmbeddingModel,
@@ -155,5 +163,7 @@ def from_pyannote(
model = EmbeddingModel.from_pyannote(model, use_hf_token)
return OverlapAwareSpeakerEmbedding(model, gamma, beta, norm, device)
- def __call__(self, waveform: TemporalFeatures, segmentation: TemporalFeatures) -> torch.Tensor:
+ def __call__(
+ self, waveform: TemporalFeatures, segmentation: TemporalFeatures
+ ) -> torch.Tensor:
return self.normalize(self.embedding(waveform, self.osp(segmentation)))
diff --git a/src/diart/blocks/segmentation.py b/src/diart/blocks/segmentation.py
index 8fda3ffc..e946c748 100644
--- a/src/diart/blocks/segmentation.py
+++ b/src/diart/blocks/segmentation.py
@@ -21,8 +21,8 @@ def __init__(self, model: SegmentationModel, device: Optional[torch.device] = No
def from_pyannote(
model,
use_hf_token: Union[Text, bool, None] = True,
- device: Optional[torch.device] = None
- ) -> 'SpeakerSegmentation':
+ device: Optional[torch.device] = None,
+ ) -> "SpeakerSegmentation":
seg_model = SegmentationModel.from_pyannote(model, use_hf_token)
return SpeakerSegmentation(seg_model, device)
@@ -40,6 +40,9 @@ def __call__(self, waveform: TemporalFeatures) -> TemporalFeatures:
The batch dimension is omitted if waveform is a `SlidingWindowFeature`.
"""
with torch.no_grad():
- wave = rearrange(self.formatter.cast(waveform), "batch sample channel -> batch channel sample")
+ wave = rearrange(
+ self.formatter.cast(waveform),
+ "batch sample channel -> batch channel sample",
+ )
output = self.model(wave.to(self.device)).cpu()
return self.formatter.restore_type(output)
diff --git a/src/diart/blocks/utils.py b/src/diart/blocks/utils.py
index 02594e3d..9c0afc3e 100644
--- a/src/diart/blocks/utils.py
+++ b/src/diart/blocks/utils.py
@@ -69,12 +69,21 @@ class Resample:
resample_rate: int
Sample rate of the output
"""
- def __init__(self, sample_rate: int, resample_rate: int):
- self.resample = T.Resample(sample_rate, resample_rate)
+
+ def __init__(
+ self,
+ sample_rate: int,
+ resample_rate: int,
+ device: Optional[torch.device] = None,
+ ):
+ self.device = device
+ if self.device is None:
+ self.device = torch.device("cpu")
+ self.resample = T.Resample(sample_rate, resample_rate).to(self.device)
self.formatter = TemporalFeatureFormatter()
def __call__(self, waveform: TemporalFeatures) -> TemporalFeatures:
- wav = self.formatter.cast(waveform) # shape (batch, samples, 1)
+ wav = self.formatter.cast(waveform).to(self.device) # shape (batch, samples, 1)
with torch.no_grad():
resampled_wav = self.resample(wav.transpose(-1, -2)).transpose(-1, -2)
return self.formatter.restore_type(resampled_wav)
@@ -90,6 +99,7 @@ class AdjustVolume:
volume_in_db: float
Target volume in dB.
"""
+
def __init__(self, volume_in_db: float):
self.target_db = volume_in_db
self.formatter = TemporalFeatureFormatter()
@@ -108,7 +118,9 @@ def get_volumes(waveforms: torch.Tensor) -> torch.Tensor:
volumes: torch.Tensor
Audio chunk volumes per channel. Shape (batch, 1, channels)
"""
- return 10 * torch.log10(torch.mean(torch.abs(waveforms) ** 2, dim=1, keepdim=True))
+ return 10 * torch.log10(
+ torch.mean(torch.abs(waveforms) ** 2, dim=1, keepdim=True)
+ )
def __call__(self, waveform: TemporalFeatures) -> TemporalFeatures:
wav = self.formatter.cast(waveform) # shape (batch, samples, channels)
diff --git a/src/diart/blocks/vad.py b/src/diart/blocks/vad.py
new file mode 100644
index 00000000..0edd3e0b
--- /dev/null
+++ b/src/diart/blocks/vad.py
@@ -0,0 +1,196 @@
+from __future__ import annotations
+
+from typing import Sequence
+
+import numpy as np
+import torch
+from pyannote.core import (
+ Annotation,
+ Timeline,
+ SlidingWindowFeature,
+ SlidingWindow,
+ Segment,
+)
+from pyannote.metrics.base import BaseMetric
+from pyannote.metrics.detection import DetectionErrorRate
+from typing_extensions import Literal
+
+from . import base
+from .aggregation import DelayedAggregation
+from .segmentation import SpeakerSegmentation
+from .utils import Binarize
+from .. import models as m
+from .. import utils
+
+
+class VoiceActivityDetectionConfig(base.PipelineConfig):
+ def __init__(
+ self,
+ segmentation: m.SegmentationModel | None = None,
+ duration: float | None = None,
+ step: float = 0.5,
+ latency: float | Literal["max", "min"] | None = None,
+ tau_active: float = 0.6,
+ device: torch.device | None = None,
+ **kwargs,
+ ):
+ # Default segmentation model is pyannote/segmentation
+ self.segmentation = segmentation or m.SegmentationModel.from_pyannote(
+ "pyannote/segmentation"
+ )
+
+ self._duration = duration
+ self._step = step
+ self._sample_rate: int | None = None
+
+ # Latency defaults to the step duration
+ self._latency = latency
+ if self._latency is None or self._latency == "min":
+ self._latency = self._step
+ elif self._latency == "max":
+ self._latency = self._duration
+
+ self.tau_active = tau_active
+ self.device = device or torch.device(
+ "cuda" if torch.cuda.is_available() else "cpu"
+ )
+
+ @property
+ def duration(self) -> float:
+ # Default duration is the one given by the segmentation model
+ if self._duration is None:
+ self._duration = self.segmentation.duration
+ return self._duration
+
+ @property
+ def step(self) -> float:
+ return self._step
+
+ @property
+ def latency(self) -> float:
+ return self._latency
+
+ @property
+ def sample_rate(self) -> int:
+ # Expected sample rate is given by the segmentation model
+ if self._sample_rate is None:
+ self._sample_rate = self.segmentation.sample_rate
+ return self._sample_rate
+
+
+class VoiceActivityDetection(base.Pipeline):
+ def __init__(self, config: VoiceActivityDetectionConfig | None = None):
+ self._config = VoiceActivityDetectionConfig() if config is None else config
+
+ msg = f"Latency should be in the range [{self._config.step}, {self._config.duration}]"
+ assert self._config.step <= self._config.latency <= self._config.duration, msg
+
+ self.segmentation = SpeakerSegmentation(
+ self._config.segmentation, self._config.device
+ )
+ self.pred_aggregation = DelayedAggregation(
+ self._config.step,
+ self._config.latency,
+ strategy="hamming",
+ cropping_mode="loose",
+ )
+ self.audio_aggregation = DelayedAggregation(
+ self._config.step,
+ self._config.latency,
+ strategy="first",
+ cropping_mode="center",
+ )
+ self.binarize = Binarize(self._config.tau_active)
+
+ # Internal state, handle with care
+ self.timestamp_shift = 0
+ self.chunk_buffer, self.pred_buffer = [], []
+
+ @staticmethod
+ def get_config_class() -> type:
+ return VoiceActivityDetectionConfig
+
+ @staticmethod
+ def suggest_metric() -> BaseMetric:
+ return DetectionErrorRate(collar=0, skip_overlap=False)
+
+ @staticmethod
+ def hyper_parameters() -> Sequence[base.HyperParameter]:
+ return [base.TauActive]
+
+ @property
+ def config(self) -> base.PipelineConfig:
+ return self._config
+
+ def reset(self):
+ self.set_timestamp_shift(0)
+ self.chunk_buffer, self.pred_buffer = [], []
+
+ def set_timestamp_shift(self, shift: float):
+ self.timestamp_shift = shift
+
+ def __call__(
+ self,
+ waveforms: Sequence[SlidingWindowFeature],
+ ) -> Sequence[tuple[Annotation, SlidingWindowFeature]]:
+ batch_size = len(waveforms)
+ msg = "Pipeline expected at least 1 input"
+ assert batch_size >= 1, msg
+
+ # Create batch from chunk sequence, shape (batch, samples, channels)
+ batch = torch.stack([torch.from_numpy(w.data) for w in waveforms])
+
+ expected_num_samples = int(
+ np.rint(self.config.duration * self.config.sample_rate)
+ )
+ msg = f"Expected {expected_num_samples} samples per chunk, but got {batch.shape[1]}"
+ assert batch.shape[1] == expected_num_samples, msg
+
+ # Extract segmentation
+ segmentations = self.segmentation(batch) # shape (batch, frames, speakers)
+ voice_detection = torch.max(segmentations, dim=-1, keepdim=True)[
+ 0
+ ] # shape (batch, frames, 1)
+
+ seg_resolution = waveforms[0].extent.duration / segmentations.shape[1]
+
+ outputs = []
+ for wav, vad in zip(waveforms, voice_detection):
+ # Add timestamps to segmentation
+ sw = SlidingWindow(
+ start=wav.extent.start,
+ duration=seg_resolution,
+ step=seg_resolution,
+ )
+ vad = SlidingWindowFeature(vad.cpu().numpy(), sw)
+
+ # Update sliding buffer
+ self.chunk_buffer.append(wav)
+ self.pred_buffer.append(vad)
+
+ # Aggregate buffer outputs for this time step
+ agg_waveform = self.audio_aggregation(self.chunk_buffer)
+ agg_prediction = self.pred_aggregation(self.pred_buffer)
+ agg_prediction = self.binarize(agg_prediction).get_timeline(copy=False)
+
+ # Shift prediction timestamps if required
+ if self.timestamp_shift != 0:
+ shifted_agg_prediction = Timeline(uri=agg_prediction.uri)
+ for segment in agg_prediction:
+ new_segment = Segment(
+ segment.start + self.timestamp_shift,
+ segment.end + self.timestamp_shift,
+ )
+ shifted_agg_prediction.add(new_segment)
+ agg_prediction = shifted_agg_prediction
+
+ # Convert timeline into annotation with single speaker "speech"
+ agg_prediction = agg_prediction.to_annotation(utils.repeat_label("speech"))
+ outputs.append((agg_prediction, agg_waveform))
+
+ # Make place for new chunks in buffer if required
+ if len(self.chunk_buffer) == self.pred_aggregation.num_overlapping_windows:
+ self.chunk_buffer = self.chunk_buffer[1:]
+ self.pred_buffer = self.pred_buffer[1:]
+
+ return outputs
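The new VAD pipeline plugs into the same streaming API as the diarization pipeline; a hedged usage sketch (defaults taken from the constructor above):

```python
from diart import VoiceActivityDetection, VoiceActivityDetectionConfig
from diart.sources import MicrophoneAudioSource
from diart.inference import StreamingInference

config = VoiceActivityDetectionConfig(tau_active=0.6)
pipeline = VoiceActivityDetection(config)
mic = MicrophoneAudioSource()
inference = StreamingInference(pipeline, mic)
prediction = inference()  # Annotation with a single "speech" label, per the pipeline above
```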
diff --git a/src/diart/console/benchmark.py b/src/diart/console/benchmark.py
index b6a3f9ff..b5a296d1 100644
--- a/src/diart/console/benchmark.py
+++ b/src/diart/console/benchmark.py
@@ -1,39 +1,116 @@
import argparse
from pathlib import Path
-import diart.argdoc as argdoc
import pandas as pd
-from diart.blocks import OnlineSpeakerDiarization, PipelineConfig
+import torch
+
+from diart import argdoc
+from diart import models as m
+from diart import utils
from diart.inference import Benchmark, Parallelize
def run():
parser = argparse.ArgumentParser()
- parser.add_argument("root", type=Path, help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)")
- parser.add_argument("--segmentation", default="pyannote/segmentation", type=str,
- help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation")
- parser.add_argument("--embedding", default="pyannote/embedding", type=str,
- help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding")
- parser.add_argument("--reference", type=Path,
- help="Optional. Directory with RTTM files CONVERSATION.rttm. Names must match audio files")
- parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5")
- parser.add_argument("--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5")
- parser.add_argument("--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5")
- parser.add_argument("--rho", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3")
- parser.add_argument("--delta", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1")
- parser.add_argument("--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. Defaults to 3")
- parser.add_argument("--beta", default=10, type=float, help=f"{argdoc.BETA}. Defaults to 10")
- parser.add_argument("--max-speakers", default=20, type=int, help=f"{argdoc.MAX_SPEAKERS}. Defaults to 20")
- parser.add_argument("--batch-size", default=32, type=int, help=f"{argdoc.BATCH_SIZE}. Defaults to 32")
- parser.add_argument("--num-workers", default=0, type=int,
- help=f"{argdoc.NUM_WORKERS}. Defaults to 0 (no parallelism)")
- parser.add_argument("--cpu", dest="cpu", action="store_true",
- help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise")
- parser.add_argument("--output", type=Path, help=f"{argdoc.OUTPUT}. Defaults to no writing")
- parser.add_argument("--hf-token", default="true", type=str,
- help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)")
+ parser.add_argument(
+ "root",
+ type=Path,
+ help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)",
+ )
+ parser.add_argument(
+ "--pipeline",
+ default="SpeakerDiarization",
+ type=str,
+ help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'",
+ )
+ parser.add_argument(
+ "--segmentation",
+ default="pyannote/segmentation",
+ type=str,
+ help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation",
+ )
+ parser.add_argument(
+ "--embedding",
+ default="pyannote/embedding",
+ type=str,
+ help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding",
+ )
+ parser.add_argument(
+ "--reference",
+ type=Path,
+ help="Optional. Directory with RTTM files CONVERSATION.rttm. Names must match audio files",
+ )
+ parser.add_argument(
+ "--duration",
+ type=float,
+ help=f"{argdoc.DURATION}. Defaults to training segmentation duration",
+ )
+ parser.add_argument(
+ "--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5"
+ )
+ parser.add_argument(
+ "--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5"
+ )
+ parser.add_argument(
+ "--tau-active", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5"
+ )
+ parser.add_argument(
+ "--rho-update", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3"
+ )
+ parser.add_argument(
+ "--delta-new", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1"
+ )
+ parser.add_argument(
+ "--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. Defaults to 3"
+ )
+ parser.add_argument(
+ "--beta", default=10, type=float, help=f"{argdoc.BETA}. Defaults to 10"
+ )
+ parser.add_argument(
+ "--max-speakers",
+ default=20,
+ type=int,
+ help=f"{argdoc.MAX_SPEAKERS}. Defaults to 20",
+ )
+ parser.add_argument(
+ "--batch-size",
+ default=32,
+ type=int,
+ help=f"{argdoc.BATCH_SIZE}. Defaults to 32",
+ )
+ parser.add_argument(
+ "--num-workers",
+ default=0,
+ type=int,
+ help=f"{argdoc.NUM_WORKERS}. Defaults to 0 (no parallelism)",
+ )
+ parser.add_argument(
+ "--cpu",
+ dest="cpu",
+ action="store_true",
+ help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise",
+ )
+ parser.add_argument(
+ "--output", type=Path, help=f"{argdoc.OUTPUT}. Defaults to no writing"
+ )
+ parser.add_argument(
+ "--hf-token",
+ default="true",
+ type=str,
+ help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)",
+ )
args = parser.parse_args()
+ # Resolve device
+ args.device = torch.device("cpu") if args.cpu else None
+
+ # Resolve models
+ hf_token = utils.parse_hf_token_arg(args.hf_token)
+ args.segmentation = m.SegmentationModel.from_pyannote(args.segmentation, hf_token)
+ args.embedding = m.EmbeddingModel.from_pyannote(args.embedding, hf_token)
+
+ pipeline_class = utils.get_pipeline_class(args.pipeline)
+
benchmark = Benchmark(
args.root,
args.reference,
@@ -43,11 +120,11 @@ def run():
batch_size=args.batch_size,
)
- config = PipelineConfig.from_dict(vars(args))
+ config = pipeline_class.get_config_class()(**vars(args))
if args.num_workers > 0:
benchmark = Parallelize(benchmark, args.num_workers)
- report = benchmark(OnlineSpeakerDiarization, config)
+ report = benchmark(pipeline_class, config)
if args.output is not None and isinstance(report, pd.DataFrame):
report.to_csv(args.output / "benchmark_report.csv")
diff --git a/src/diart/console/client.py b/src/diart/console/client.py
index 084dbc13..b3de36db 100644
--- a/src/diart/console/client.py
+++ b/src/diart/console/client.py
@@ -3,28 +3,25 @@
from threading import Thread
from typing import Text, Optional
-import diart.argdoc as argdoc
-import diart.sources as src
-import diart.utils as utils
-import numpy as np
import rx.operators as ops
from websocket import WebSocket
+from diart import argdoc
+from diart import sources as src
+from diart import utils
+
def send_audio(ws: WebSocket, source: Text, step: float, sample_rate: int):
# Create audio source
- block_size = int(np.rint(step * sample_rate))
source_components = source.split(":")
if source_components[0] != "microphone":
- audio_source = src.FileAudioSource(source, sample_rate)
+ audio_source = src.FileAudioSource(source, sample_rate, block_duration=step)
else:
device = int(source_components[1]) if len(source_components) > 1 else None
- audio_source = src.MicrophoneAudioSource(sample_rate, block_size, device)
+ audio_source = src.MicrophoneAudioSource(step, device)
# Encode audio, then send through websocket
- audio_source.stream.pipe(
- ops.map(utils.encode_audio)
- ).subscribe_(ws.send)
+ audio_source.stream.pipe(ops.map(utils.encode_audio)).subscribe_(ws.send)
# Start reading audio
audio_source.read()
@@ -41,18 +38,37 @@ def receive_audio(ws: WebSocket, output: Optional[Path]):
def run():
parser = argparse.ArgumentParser()
- parser.add_argument("source", type=str, help="Path to an audio file | 'microphone' | 'microphone:'")
+ parser.add_argument(
+ "source",
+ type=str,
+ help="Path to an audio file | 'microphone' | 'microphone:'",
+ )
parser.add_argument("--host", required=True, type=str, help="Server host")
parser.add_argument("--port", required=True, type=int, help="Server port")
- parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5")
- parser.add_argument("-sr", "--sample-rate", default=16000, type=int, help=f"{argdoc.SAMPLE_RATE}. Defaults to 16000")
- parser.add_argument("-o", "--output-file", type=Path, help="Output RTTM file. Defaults to no writing")
+ parser.add_argument(
+ "--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5"
+ )
+ parser.add_argument(
+ "-sr",
+ "--sample-rate",
+ default=16000,
+ type=int,
+ help=f"{argdoc.SAMPLE_RATE}. Defaults to 16000",
+ )
+ parser.add_argument(
+ "-o",
+ "--output-file",
+ type=Path,
+ help="Output RTTM file. Defaults to no writing",
+ )
args = parser.parse_args()
# Run websocket client
ws = WebSocket()
ws.connect(f"ws://{args.host}:{args.port}")
- sender = Thread(target=send_audio, args=[ws, args.source, args.step, args.sample_rate])
+ sender = Thread(
+ target=send_audio, args=[ws, args.source, args.step, args.sample_rate]
+ )
receiver = Thread(target=receive_audio, args=[ws, args.output_file])
sender.start()
receiver.start()
diff --git a/src/diart/console/serve.py b/src/diart/console/serve.py
index 2f632d57..bc002e42 100644
--- a/src/diart/console/serve.py
+++ b/src/diart/console/serve.py
@@ -1,10 +1,13 @@
import argparse
from pathlib import Path
-import diart.argdoc as argdoc
-import diart.sources as src
-from diart.blocks import OnlineSpeakerDiarization, PipelineConfig
-from diart.inference import RealTimeInference
+import torch
+
+from diart import argdoc
+from diart import models as m
+from diart import sources as src
+from diart import utils
+from diart.inference import StreamingInference
from diart.sinks import RTTMWriter
@@ -12,34 +15,91 @@ def run():
parser = argparse.ArgumentParser()
parser.add_argument("--host", default="0.0.0.0", type=str, help="Server host")
parser.add_argument("--port", default=7007, type=int, help="Server port")
- parser.add_argument("--segmentation", default="pyannote/segmentation", type=str,
- help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation")
- parser.add_argument("--embedding", default="pyannote/embedding", type=str,
- help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding")
- parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5")
- parser.add_argument("--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5")
- parser.add_argument("--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5")
- parser.add_argument("--rho", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3")
- parser.add_argument("--delta", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1")
- parser.add_argument("--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. Defaults to 3")
- parser.add_argument("--beta", default=10, type=float, help=f"{argdoc.BETA}. Defaults to 10")
- parser.add_argument("--max-speakers", default=20, type=int, help=f"{argdoc.MAX_SPEAKERS}. Defaults to 20")
- parser.add_argument("--cpu", dest="cpu", action="store_true",
- help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise")
- parser.add_argument("--output", type=Path, help=f"{argdoc.OUTPUT}. Defaults to no writing")
- parser.add_argument("--hf-token", default="true", type=str,
- help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)")
+ parser.add_argument(
+ "--pipeline",
+ default="SpeakerDiarization",
+ type=str,
+ help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'",
+ )
+ parser.add_argument(
+ "--segmentation",
+ default="pyannote/segmentation",
+ type=str,
+ help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation",
+ )
+ parser.add_argument(
+ "--embedding",
+ default="pyannote/embedding",
+ type=str,
+ help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding",
+ )
+ parser.add_argument(
+ "--duration",
+ type=float,
+ help=f"{argdoc.DURATION}. Defaults to training segmentation duration",
+ )
+ parser.add_argument(
+ "--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5"
+ )
+ parser.add_argument(
+ "--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5"
+ )
+ parser.add_argument(
+ "--tau-active", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5"
+ )
+ parser.add_argument(
+ "--rho-update", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3"
+ )
+ parser.add_argument(
+ "--delta-new", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1"
+ )
+ parser.add_argument(
+ "--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. Defaults to 3"
+ )
+ parser.add_argument(
+ "--beta", default=10, type=float, help=f"{argdoc.BETA}. Defaults to 10"
+ )
+ parser.add_argument(
+ "--max-speakers",
+ default=20,
+ type=int,
+ help=f"{argdoc.MAX_SPEAKERS}. Defaults to 20",
+ )
+ parser.add_argument(
+ "--cpu",
+ dest="cpu",
+ action="store_true",
+ help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise",
+ )
+ parser.add_argument(
+ "--output", type=Path, help=f"{argdoc.OUTPUT}. Defaults to no writing"
+ )
+ parser.add_argument(
+ "--hf-token",
+ default="true",
+ type=str,
+ help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)",
+ )
args = parser.parse_args()
- # Define online speaker diarization pipeline
- config = PipelineConfig.from_dict(vars(args))
- pipeline = OnlineSpeakerDiarization(config)
+ # Resolve device
+ args.device = torch.device("cpu") if args.cpu else None
+
+ # Resolve models
+ hf_token = utils.parse_hf_token_arg(args.hf_token)
+ args.segmentation = m.SegmentationModel.from_pyannote(args.segmentation, hf_token)
+ args.embedding = m.EmbeddingModel.from_pyannote(args.embedding, hf_token)
+
+ # Resolve pipeline
+ pipeline_class = utils.get_pipeline_class(args.pipeline)
+ config = pipeline_class.get_config_class()(**vars(args))
+ pipeline = pipeline_class(config)
# Create websocket audio source
audio_source = src.WebSocketAudioSource(config.sample_rate, args.host, args.port)
# Run online inference
- inference = RealTimeInference(
+ inference = StreamingInference(
pipeline,
audio_source,
batch_size=1,
@@ -50,7 +110,9 @@ def run():
# Write to disk if required
if args.output is not None:
- inference.attach_observers(RTTMWriter(audio_source.uri, args.output / f"{audio_source.uri}.rttm"))
+ inference.attach_observers(
+ RTTMWriter(audio_source.uri, args.output / f"{audio_source.uri}.rttm")
+ )
# Send back responses as RTTM text lines
inference.attach_hooks(lambda ann_wav: audio_source.send(ann_wav[0].to_rttm()))
diff --git a/src/diart/console/stream.py b/src/diart/console/stream.py
index d7218f07..713f3e99 100644
--- a/src/diart/console/stream.py
+++ b/src/diart/console/stream.py
@@ -1,57 +1,130 @@
import argparse
from pathlib import Path
-import diart.argdoc as argdoc
-import diart.sources as src
-from diart.blocks import OnlineSpeakerDiarization, PipelineConfig
-from diart.inference import RealTimeInference
+import torch
+
+from diart import argdoc
+from diart import models as m
+from diart import sources as src
+from diart import utils
+from diart.inference import StreamingInference
from diart.sinks import RTTMWriter
def run():
parser = argparse.ArgumentParser()
- parser.add_argument("source", type=str, help="Path to an audio file | 'microphone' | 'microphone:'")
- parser.add_argument("--segmentation", default="pyannote/segmentation", type=str,
- help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation")
- parser.add_argument("--embedding", default="pyannote/embedding", type=str,
- help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding")
- parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5")
- parser.add_argument("--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5")
- parser.add_argument("--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5")
- parser.add_argument("--rho", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3")
- parser.add_argument("--delta", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1")
- parser.add_argument("--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. Defaults to 3")
- parser.add_argument("--beta", default=10, type=float, help=f"{argdoc.BETA}. Defaults to 10")
- parser.add_argument("--max-speakers", default=20, type=int, help=f"{argdoc.MAX_SPEAKERS}. Defaults to 20")
- parser.add_argument("--no-plot", dest="no_plot", action="store_true", help="Skip plotting for faster inference")
- parser.add_argument("--cpu", dest="cpu", action="store_true",
- help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise")
- parser.add_argument("--output", type=str,
- help=f"{argdoc.OUTPUT}. Defaults to home directory if SOURCE == 'microphone' or parent directory if SOURCE is a file")
- parser.add_argument("--hf-token", default="true", type=str,
- help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)")
+ parser.add_argument(
+ "source",
+ type=str,
+ help="Path to an audio file | 'microphone' | 'microphone:'",
+ )
+ parser.add_argument(
+ "--pipeline",
+ default="SpeakerDiarization",
+ type=str,
+ help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'",
+ )
+ parser.add_argument(
+ "--segmentation",
+ default="pyannote/segmentation",
+ type=str,
+ help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation",
+ )
+ parser.add_argument(
+ "--embedding",
+ default="pyannote/embedding",
+ type=str,
+ help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding",
+ )
+ parser.add_argument(
+ "--duration",
+ type=float,
+ help=f"{argdoc.DURATION}. Defaults to training segmentation duration",
+ )
+ parser.add_argument(
+ "--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5"
+ )
+ parser.add_argument(
+ "--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5"
+ )
+ parser.add_argument(
+ "--tau-active", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5"
+ )
+ parser.add_argument(
+ "--rho-update", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3"
+ )
+ parser.add_argument(
+ "--delta-new", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1"
+ )
+ parser.add_argument(
+ "--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. Defaults to 3"
+ )
+ parser.add_argument(
+ "--beta", default=10, type=float, help=f"{argdoc.BETA}. Defaults to 10"
+ )
+ parser.add_argument(
+ "--max-speakers",
+ default=20,
+ type=int,
+ help=f"{argdoc.MAX_SPEAKERS}. Defaults to 20",
+ )
+ parser.add_argument(
+ "--no-plot",
+ dest="no_plot",
+ action="store_true",
+ help="Skip plotting for faster inference",
+ )
+ parser.add_argument(
+ "--cpu",
+ dest="cpu",
+ action="store_true",
+ help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise",
+ )
+ parser.add_argument(
+ "--output",
+ type=str,
+ help=f"{argdoc.OUTPUT}. Defaults to home directory if SOURCE == 'microphone' or parent directory if SOURCE is a file",
+ )
+ parser.add_argument(
+ "--hf-token",
+ default="true",
+ type=str,
+ help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)",
+ )
args = parser.parse_args()
- # Define online speaker diarization pipeline
- config = PipelineConfig.from_dict(vars(args))
- pipeline = OnlineSpeakerDiarization(config)
+ # Resolve device
+ args.device = torch.device("cpu") if args.cpu else None
+
+ # Resolve models
+ hf_token = utils.parse_hf_token_arg(args.hf_token)
+ args.segmentation = m.SegmentationModel.from_pyannote(args.segmentation, hf_token)
+ args.embedding = m.EmbeddingModel.from_pyannote(args.embedding, hf_token)
+
+ # Resolve pipeline
+ pipeline_class = utils.get_pipeline_class(args.pipeline)
+ config = pipeline_class.get_config_class()(**vars(args))
+ pipeline = pipeline_class(config)
# Manage audio source
- block_size = config.optimal_block_size()
source_components = args.source.split(":")
if source_components[0] != "microphone":
args.source = Path(args.source).expanduser()
args.output = args.source.parent if args.output is None else Path(args.output)
padding = config.get_file_padding(args.source)
- audio_source = src.FileAudioSource(args.source, config.sample_rate, padding, block_size)
+ audio_source = src.FileAudioSource(
+ args.source, config.sample_rate, padding, config.step
+ )
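+        # Shift output timestamps back by the left padding so they match the original file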
pipeline.set_timestamp_shift(-padding[0])
else:
- args.output = Path("~/").expanduser() if args.output is None else Path(args.output)
+ args.output = (
+ Path("~/").expanduser() if args.output is None else Path(args.output)
+ )
device = int(source_components[1]) if len(source_components) > 1 else None
- audio_source = src.MicrophoneAudioSource(config.sample_rate, block_size, device)
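+        # `device` is the optional sound card index parsed from the 'microphone:<DEVICE_ID>' syntax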
+ audio_source = src.MicrophoneAudioSource(config.step, device)
# Run online inference
- inference = RealTimeInference(
+ inference = StreamingInference(
pipeline,
audio_source,
batch_size=1,
@@ -59,8 +132,13 @@ def run():
do_plot=not args.no_plot,
show_progress=True,
)
- inference.attach_observers(RTTMWriter(audio_source.uri, args.output / f"{audio_source.uri}.rttm"))
- inference()
+ inference.attach_observers(
+ RTTMWriter(audio_source.uri, args.output / f"{audio_source.uri}.rttm")
+ )
+ try:
+ inference()
+ except KeyboardInterrupt:
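+        # Allow stopping a live session with Ctrl+C without printing a traceback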
+ pass
if __name__ == "__main__":
diff --git a/src/diart/console/tune.py b/src/diart/console/tune.py
index 4ad8852a..ec243348 100644
--- a/src/diart/console/tune.py
+++ b/src/diart/console/tune.py
@@ -1,54 +1,145 @@
import argparse
from pathlib import Path
-import diart.argdoc as argdoc
import optuna
-from diart.blocks import PipelineConfig, OnlineSpeakerDiarization
-from diart.optim import Optimizer, HyperParameter
+import torch
from optuna.samplers import TPESampler
+from diart import argdoc
+from diart import models as m
+from diart import utils
+from diart.blocks.base import HyperParameter
+from diart.optim import Optimizer
+
def run():
parser = argparse.ArgumentParser()
- parser.add_argument("root", type=str, help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)")
- parser.add_argument("--reference", required=True, type=str,
- help="Directory with RTTM files CONVERSATION.rttm. Names must match audio files")
- parser.add_argument("--segmentation", default="pyannote/segmentation", type=str,
- help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation")
- parser.add_argument("--embedding", default="pyannote/embedding", type=str,
- help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding")
- parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5")
- parser.add_argument("--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5")
- parser.add_argument("--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5")
- parser.add_argument("--rho", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3")
- parser.add_argument("--delta", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1")
- parser.add_argument("--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. Defaults to 3")
- parser.add_argument("--beta", default=10, type=float, help=f"{argdoc.BETA}. Defaults to 10")
- parser.add_argument("--max-speakers", default=20, type=int, help=f"{argdoc.MAX_SPEAKERS}. Defaults to 20")
- parser.add_argument("--batch-size", default=32, type=int, help=f"{argdoc.BATCH_SIZE}. Defaults to 32")
- parser.add_argument("--cpu", dest="cpu", action="store_true",
- help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise")
- parser.add_argument("--hparams", nargs="+", default=("tau_active", "rho_update", "delta_new"),
- help="Hyper-parameters to optimize. Must match names in `PipelineConfig`. Defaults to tau_active, rho_update and delta_new")
- parser.add_argument("--num-iter", default=100, type=int, help="Number of optimization trials")
- parser.add_argument("--storage", type=str,
- help="Optuna storage string. If provided, continue a previous study instead of creating one. The database name must match the study name")
+ parser.add_argument(
+ "root",
+ type=str,
+ help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)",
+ )
+ parser.add_argument(
+ "--reference",
+ required=True,
+ type=str,
+ help="Directory with RTTM files CONVERSATION.rttm. Names must match audio files",
+ )
+ parser.add_argument(
+ "--pipeline",
+ default="SpeakerDiarization",
+ type=str,
+ help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'",
+ )
+ parser.add_argument(
+ "--segmentation",
+ default="pyannote/segmentation",
+ type=str,
+ help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation",
+ )
+ parser.add_argument(
+ "--embedding",
+ default="pyannote/embedding",
+ type=str,
+ help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding",
+ )
+ parser.add_argument(
+ "--duration",
+ type=float,
+ help=f"{argdoc.DURATION}. Defaults to training segmentation duration",
+ )
+ parser.add_argument(
+ "--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5"
+ )
+ parser.add_argument(
+ "--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5"
+ )
+ parser.add_argument(
+ "--tau-active", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5"
+ )
+ parser.add_argument(
+ "--rho-update", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3"
+ )
+ parser.add_argument(
+ "--delta-new", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1"
+ )
+ parser.add_argument(
+ "--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. Defaults to 3"
+ )
+ parser.add_argument(
+ "--beta", default=10, type=float, help=f"{argdoc.BETA}. Defaults to 10"
+ )
+ parser.add_argument(
+ "--max-speakers",
+ default=20,
+ type=int,
+ help=f"{argdoc.MAX_SPEAKERS}. Defaults to 20",
+ )
+ parser.add_argument(
+ "--batch-size",
+ default=32,
+ type=int,
+ help=f"{argdoc.BATCH_SIZE}. Defaults to 32",
+ )
+ parser.add_argument(
+ "--cpu",
+ dest="cpu",
+ action="store_true",
+ help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise",
+ )
+ parser.add_argument(
+ "--hparams",
+ nargs="+",
+ default=("tau_active", "rho_update", "delta_new"),
+ help="Hyper-parameters to optimize. Must match names in `PipelineConfig`. Defaults to tau_active, rho_update and delta_new",
+ )
+ parser.add_argument(
+ "--num-iter", default=100, type=int, help="Number of optimization trials"
+ )
+ parser.add_argument(
+ "--storage",
+ type=str,
+ help="Optuna storage string. If provided, continue a previous study instead of creating one. The database name must match the study name",
+ )
parser.add_argument("--output", type=str, help="Working directory")
- parser.add_argument("--hf-token", default="true", type=str,
- help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)")
+ parser.add_argument(
+ "--hf-token",
+ default="true",
+ type=str,
+ help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)",
+ )
args = parser.parse_args()
+ # Resolve device
+ args.device = torch.device("cpu") if args.cpu else None
+
+ # Resolve models
+ hf_token = utils.parse_hf_token_arg(args.hf_token)
+ args.segmentation = m.SegmentationModel.from_pyannote(args.segmentation, hf_token)
+ args.embedding = m.EmbeddingModel.from_pyannote(args.embedding, hf_token)
+
+ # Retrieve pipeline class
+ pipeline_class = utils.get_pipeline_class(args.pipeline)
+
# Create the base configuration for each trial
- base_config = PipelineConfig.from_dict(vars(args))
+ base_config = pipeline_class.get_config_class()(**vars(args))
# Create hyper-parameters to optimize
+ possible_hparams = pipeline_class.hyper_parameters()
hparams = [HyperParameter.from_name(name) for name in args.hparams]
+ hparams = [hp for hp in hparams if hp in possible_hparams]
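+    # Silently drop requested hyper-parameters that the chosen pipeline does not expose;
+    # if none remain, the error below lists the valid options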
+ if not hparams:
+ print(
+ f"No hyper-parameters to optimize. "
+ f"Make sure to select one of: {', '.join([hp.name for hp in possible_hparams])}"
+ )
+ exit(1)
# Use a custom storage if given
if args.output is not None:
msg = "Both `output` and `storage` were set, but only one was expected"
assert args.storage is None, msg
- args.output = Path(args.output)
+ args.output = Path(args.output).expanduser()
args.output.mkdir(parents=True, exist_ok=True)
study_or_path = args.output
elif args.storage is not None:
@@ -60,11 +151,11 @@ def run():
# Run optimization
Optimizer(
+ pipeline_class=pipeline_class,
speech_path=args.root,
reference_path=args.reference,
study_or_path=study_or_path,
batch_size=args.batch_size,
- pipeline_class=OnlineSpeakerDiarization,
hparams=hparams,
base_config=base_config,
)(num_iter=args.num_iter, show_progress=True)
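+    # Illustrative invocation (entry-point name and paths are placeholders):
+    #   diart.tune /data/audio --reference /data/rttm --output /tmp/study \
+    #       --hparams tau_active rho_update delta_new --num-iter 100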
diff --git a/src/diart/features.py b/src/diart/features.py
index 2489027a..2d5df672 100644
--- a/src/diart/features.py
+++ b/src/diart/features.py
@@ -1,4 +1,5 @@
from typing import Union, Optional
+from abc import ABC, abstractmethod
import numpy as np
import torch
@@ -7,15 +8,18 @@
TemporalFeatures = Union[SlidingWindowFeature, np.ndarray, torch.Tensor]
-class TemporalFeatureFormatterState:
+class TemporalFeatureFormatterState(ABC):
"""
Represents the recorded type of a temporal feature formatter.
Its job is to transform temporal features into tensors and
recover the original format on other features.
"""
+
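+    # Declared abstract so every concrete state must implement both conversions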
+ @abstractmethod
def to_tensor(self, features: TemporalFeatures) -> torch.Tensor:
- raise NotImplementedError
+ pass
+ @abstractmethod
def to_internal_type(self, features: torch.Tensor) -> TemporalFeatures:
"""
Cast `features` to the representing type and remove batch dimension if required.
@@ -28,7 +32,7 @@ def to_internal_type(self, features: torch.Tensor) -> TemporalFeatures:
-------
new_features: SlidingWindowFeature or numpy.ndarray or torch.Tensor, shape (batch, frames, dim)
"""
- raise NotImplementedError
+ pass
class SlidingWindowFeatureFormatterState(TemporalFeatureFormatterState):
@@ -48,7 +52,9 @@ def to_internal_type(self, features: torch.Tensor) -> TemporalFeatures:
# Calculate resolution
resolution = self.duration / num_frames
# Temporal shift to keep track of current start time
- resolution = SlidingWindow(start=self._cur_start_time, duration=resolution, step=resolution)
+ resolution = SlidingWindow(
+ start=self._cur_start_time, duration=resolution, step=resolution
+ )
return SlidingWindowFeature(features.squeeze(dim=0).cpu().numpy(), resolution)
@@ -74,6 +80,7 @@ class TemporalFeatureFormatter:
When casting temporal features as torch.Tensor, it remembers its
type and format so it can lately restore it on other temporal features.
"""
+
def __init__(self):
self.state: Optional[TemporalFeatureFormatterState] = None
diff --git a/src/diart/inference.py b/src/diart/inference.py
index f4b65f5f..3eb72930 100644
--- a/src/diart/inference.py
+++ b/src/diart/inference.py
@@ -4,32 +4,33 @@
from traceback import print_exc
from typing import Union, Text, Optional, Callable, Tuple, List
-import diart.operators as dops
-import diart.sources as src
import numpy as np
import pandas as pd
import rx
import rx.operators as ops
import torch
-from diart import utils
-from diart.blocks import BasePipeline, Resample, BasePipelineConfig
-from diart.progress import ProgressBar, RichProgressBar, TQDMProgressBar
-from diart.sinks import DiarizationPredictionAccumulator, RealTimePlot, WindowClosedException
from pyannote.core import Annotation, SlidingWindowFeature
from pyannote.database.util import load_rttm
-from pyannote.metrics.diarization import DiarizationErrorRate
+from pyannote.metrics.base import BaseMetric
from rx.core import Observer
from tqdm import tqdm
+from . import blocks
+from . import operators as dops
+from . import sources as src
+from . import utils
+from .progress import ProgressBar, RichProgressBar, TQDMProgressBar
+from .sinks import PredictionAccumulator, StreamingPlot, WindowClosedException
-class RealTimeInference:
+
+class StreamingInference:
"""Performs inference in real time given a pipeline and an audio source.
Streams an audio source to an online speaker diarization pipeline.
It allows users to attach a chain of operations in the form of hooks.
Parameters
----------
- pipeline: BasePipeline
+ pipeline: StreamingPipeline
Configured speaker diarization pipeline.
source: AudioSource
Audio source to be read and streamed.
@@ -50,9 +51,10 @@ class RealTimeInference:
If description is not provided, set to 'Streaming