Version 0.5 #87

Merged 26 commits on Aug 31, 2022
Changes from all commits
a828661
Update README.md
juanmc2005 Jul 26, 2022
ea609fa
Update README.md
juanmc2005 Jul 26, 2022
a7befd7
add `study_or_path` as a Path for conversion from string
AMITKESARI2000 Jul 26, 2022
b6e048c
Add WebSocketAudioSource
juanmc2005 Jul 27, 2022
7a4114f
Add WebSocket section to README.md
juanmc2005 Jul 27, 2022
2262b04
Replace uri to avoid path error
ckliao-nccu Jul 28, 2022
7223b54
Update README.md
ckliao-nccu Jul 28, 2022
a5948dc
Update README.md
juanmc2005 Jul 29, 2022
9294a06
Update README.md
juanmc2005 Jul 29, 2022
e0ebb96
Make RealTimeInference compatible with websockets. RealTimeInference …
juanmc2005 Jul 29, 2022
1b2b289
Greatly simplify the optim API by setting sensible defaults
juanmc2005 Jul 29, 2022
97a5b59
Add on-the-fly resampling
juanmc2005 Aug 8, 2022
f1aa182
Add method to convert SpeakerMap into a dictionary. Bug fixes and docs
juanmc2005 Aug 16, 2022
6bbc226
Fix bug with empty RTTMs (#81)
zaouk Aug 17, 2022
d24c0e8
Merge newest changes from develop
juanmc2005 Aug 17, 2022
a4742f4
Add SetVolume block to change the volume of audio chunks
juanmc2005 Aug 17, 2022
0ef3a11
Fix inverted decibels in SetVolume
juanmc2005 Aug 17, 2022
b66e05c
Rename SetVolume to AdjustVolume
juanmc2005 Aug 17, 2022
ca15311
Add diart.stream arguments to change pyannote models
juanmc2005 Aug 23, 2022
9880003
Add model arguments in diart.benchmark and diart.tune. Other improvem…
juanmc2005 Aug 26, 2022
00c4936
Rename RTTMAccumulator to DiarizationPredictionAccumulator
juanmc2005 Aug 30, 2022
fd237c4
Improve websocket section in README and clarify a TODO comment
juanmc2005 Aug 31, 2022
dc5e73e
Merge pull request #77 from juanmc2005/feat/ws
juanmc2005 Aug 31, 2022
e4a11b0
Export csv report in diart.benchmark when output is provided
juanmc2005 Aug 31, 2022
a978bbb
Resolve conflicts between develop and main
juanmc2005 Aug 31, 2022
b75dc9f
Change version to 0.5.0
juanmc2005 Aug 31, 2022
README.md (106 changes: 59 additions & 47 deletions)
@@ -22,8 +22,8 @@
Stream audio
</a>
<span> | </span>
<a href="#add-your-model">
Add your model
<a href="#custom-models">
Custom models
</a>
<span> | </span>
<a href="#tune-hyper-parameters">
@@ -34,6 +34,10 @@
Build pipelines
</a>
<br/>
<a href="#websockets">
WebSockets
</a>
<span> | </span>
<a href="#powered-by-research">
Research
</a>
@@ -72,10 +76,10 @@ conda install pysoundfile -c conda-forge

3) [Install PyTorch](https://pytorch.org/get-started/locally/#start-locally)

-4) Install pyannote.audio 2.0 (currently no official release)
+4) Install pyannote.audio

```shell
-pip install git+https://github.com/pyannote/pyannote-audio.git@2.0.1#egg=pyannote-audio
+pip install pyannote.audio==2.0.1
```

**Note:** starting from version 0.4, installing pyannote.audio is mandatory to run the default system or to use pyannote-based models. In any other case, this step can be ignored.
@@ -105,25 +109,26 @@ See `diart.stream -h` for more options.

### From python

-Run a real-time speaker diarization pipeline over an audio stream with `RealTimeInference`:
+Use `RealTimeInference` to easily run a pipeline on an audio source and write the results to disk:

```python
from diart.sources import MicrophoneAudioSource
from diart.inference import RealTimeInference
-from diart.pipelines import OnlineSpeakerDiarization, PipelineConfig
-
-config = PipelineConfig()  # Default parameters
-pipeline = OnlineSpeakerDiarization(config)
-audio_source = MicrophoneAudioSource(config.sample_rate)
-inference = RealTimeInference("/output/path", do_plot=True)
-inference(pipeline, audio_source)
+from diart.pipelines import OnlineSpeakerDiarization
+from diart.sinks import RTTMWriter
+
+pipeline = OnlineSpeakerDiarization()
+mic = MicrophoneAudioSource(pipeline.config.sample_rate)
+inference = RealTimeInference(pipeline, mic, do_plot=True)
+inference.attach_observers(RTTMWriter("/output/file.rttm"))
+inference()
```
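
Each line of the resulting RTTM file describes one speech turn as an onset and a duration in seconds, plus a speaker label. For reference, an illustrative line as `RTTMWriter` would produce it (the file name and label here are made up):

```
SPEAKER file 1 12.34 2.56 <NA> <NA> speaker0 <NA> <NA>
```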

-For faster inference and evaluation on a dataset we recommend to use `Benchmark` instead (see our notes on [reproducibility](#reproducibility)).
+For inference and evaluation on a dataset we recommend using `Benchmark` (see notes on [reproducibility](#reproducibility)).

-## Add your model
+## Custom models

-Third-party segmentation and embedding models can be integrated seamlessly by subclassing `SegmentationModel` and `EmbeddingModel`:
+Third-party models can be integrated seamlessly by subclassing `SegmentationModel` and `EmbeddingModel`:

```python
import torch
@@ -148,8 +153,8 @@ class MyEmbeddingModel(EmbeddingModel):
config = PipelineConfig(embedding=MyEmbeddingModel())
pipeline = OnlineSpeakerDiarization(config)
mic = MicrophoneAudioSource(config.sample_rate)
inference = RealTimeInference("/out/dir")
inference(pipeline, mic)
inference = RealTimeInference(pipeline, mic)
inference()
```

## Tune hyper-parameters
@@ -159,31 +164,21 @@ Diart implements a hyper-parameter optimizer based on [optuna](https://optuna.re
### From the command line

```shell
-diart.tune /wav/dir --reference /rttm/dir --output /out/dir
+diart.tune /wav/dir --reference /rttm/dir --output /output/dir
```

See `diart.tune -h` for more options.

### From python

```python
-from diart.optim import Optimizer, TauActive, RhoUpdate, DeltaNew
-from diart.pipelines import PipelineConfig
-from diart.inference import Benchmark
+from diart.optim import Optimizer

-# Benchmark runs and evaluates the pipeline on a dataset
-benchmark = Benchmark("/wav/dir", "/rttm/dir", "/out/dir/tmp", show_report=False)
-# Base configuration for the pipeline we're going to tune
-base_config = PipelineConfig()
-# Hyper-parameters to optimize
-hparams = [TauActive, RhoUpdate, DeltaNew]
-# Optimizer implements the optimization loop
-optimizer = Optimizer(benchmark, base_config, hparams, "/out/dir")
-# Run optimization
-optimizer.optimize(num_iter=100, show_progress=True)
+optimizer = Optimizer("/wav/dir", "/rttm/dir", "/output/dir")
+optimizer(num_iter=100)
```

-This will use `/out/dir/tmp` as a working directory and write results to an sqlite database in `/out/dir`.
+This will write results to an SQLite database in `/output/dir`.
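
To inspect tuning results afterwards, the study can be loaded back with optuna's standard API. A minimal sketch, assuming an SQLite file in `/output/dir` (the study and file names below are placeholders, not confirmed by this PR):

```python
import optuna

# Placeholder names: check the actual .db file written to /output/dir
study = optuna.load_study(
    study_name="my_study",
    storage="sqlite:////output/dir/my_study.db",
)
print(study.best_params)  # best hyper-parameters found so far
print(study.best_value)   # corresponding best objective value
```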

### Distributed optimization

@@ -195,26 +190,23 @@ mysql -u root -e "CREATE DATABASE IF NOT EXISTS example"
optuna create-study --study-name "example" --storage "mysql://root@localhost/example"
```

-Then you can run multiple identical optimizers pointing to the database:
+You can now run multiple identical optimizers pointing to this database:

```shell
-diart.tune /wav/dir --reference /rttm/dir --output /out/dir --storage mysql://root@localhost/example
+diart.tune /wav/dir --reference /rttm/dir --storage mysql://root@localhost/example
```

-If you are using the python API, make sure that worker directories are different to avoid concurrency issues:
+or in python:

```python
from diart.optim import Optimizer
-from diart.inference import Benchmark
from optuna.samplers import TPESampler
import optuna

-ID = 0  # Worker identifier
-base_config, hparams = ...
-benchmark = Benchmark("/wav/dir", "/rttm/dir", f"/out/dir/worker-{ID}", show_report=False)
-study = optuna.load_study("example", "mysql://root@localhost/example", TPESampler())
-optimizer = Optimizer(benchmark, base_config, hparams, study)
-optimizer.optimize(num_iter=100, show_progress=True)
+db = "mysql://root@localhost/example"
+study = optuna.load_study("example", db, TPESampler())
+optimizer = Optimizer("/wav/dir", "/rttm/dir", study)
+optimizer(num_iter=100)
```
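
Since every worker shares the same storage, scaling out is just a matter of launching more of them. An illustrative way to run four workers locally:

```shell
# Illustrative: 4 identical tuning workers against the shared storage
for i in 1 2 3 4; do
  diart.tune /wav/dir --reference /rttm/dir --storage mysql://root@localhost/example &
done
wait
```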

## Build pipelines
@@ -256,6 +248,24 @@ torch.Size([4, 512])
...
```

+## WebSockets
+
+Diart is also compatible with the WebSocket protocol to serve pipelines on the web.
+
+In the following example we build a minimal server that receives audio chunks and sends back predictions in RTTM format:
+
+```python
+from diart.pipelines import OnlineSpeakerDiarization
+from diart.sources import WebSocketAudioSource
+from diart.inference import RealTimeInference
+
+pipeline = OnlineSpeakerDiarization()
+source = WebSocketAudioSource(pipeline.config.sample_rate, "localhost", 7007)
+inference = RealTimeInference(pipeline, source, do_plot=True)
+inference.attach_hooks(lambda ann_wav: source.send(ann_wav[0].to_rttm()))
+inference()
+```
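
A client can then stream audio to `ws://localhost:7007` and read back the RTTM predictions. A minimal client sketch using the `websockets` package (already a diart dependency); the chunk encoding expected by `WebSocketAudioSource` is an assumption here:

```python
import asyncio

import websockets


async def main():
    # Connect to the server from the example above
    async with websockets.connect("ws://localhost:7007") as ws:
        # Send one audio chunk; whether WebSocketAudioSource expects raw
        # bytes, base64 or another encoding is an assumption here
        with open("chunk.wav", "rb") as f:
            await ws.send(f.read())
        # Print the RTTM prediction sent back by the server hook
        print(await ws.recv())


asyncio.run(main())
```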

## Powered by research

Diart is the official implementation of the paper *[Overlap-aware low-latency online speaker diarization based on end-to-end local segmentation](/paper.pdf)* by [Juan Manuel Coria](https://juanmc2005.github.io/), [Hervé Bredin](https://herve.niderb.fr), [Sahar Ghannay](https://saharghannay.github.io/) and [Sophie Rosset](https://perso.limsi.fr/rosset/).
@@ -299,32 +309,34 @@ To obtain the best results, make sure to use the following hyper-parameters:
| DIHARD II | 1s | 0.619 | 0.326 | 0.997 |
| DIHARD II | 5s | 0.555 | 0.422 | 1.517 |

-`diart.benchmark` and `diart.inference.Benchmark` can quickly run and evaluate the pipeline, and even measure its real-time latency. For instance, for a DIHARD III configuration:
+`diart.benchmark` and `diart.inference.Benchmark` can run, evaluate and measure the real-time latency of the pipeline. For instance, for a DIHARD III configuration:

```shell
-diart.benchmark /wav/dir --reference /rttm/dir --tau=0.555 --rho=0.422 --delta=1.517 --output /out/dir
+diart.benchmark /wav/dir --reference /rttm/dir --tau=0.555 --rho=0.422 --delta=1.517 --segmentation pyannote/segmentation@Interspeech2021
```

or using the inference API:

```python
from diart.inference import Benchmark
from diart.pipelines import OnlineSpeakerDiarization, PipelineConfig
+from diart.models import SegmentationModel

config = PipelineConfig(
+    # Set the model used in the paper
+    segmentation=SegmentationModel.from_pyannote("pyannote/segmentation@Interspeech2021"),
    step=0.5,
    latency=0.5,
    tau_active=0.555,
    rho_update=0.422,
    delta_new=1.517
)
pipeline = OnlineSpeakerDiarization(config)
benchmark = Benchmark("/wav/dir", "/rttm/dir", "/out/dir")

benchmark = Benchmark("/wav/dir", "/rttm/dir")
benchmark(pipeline)
```

-This runs a faster inference by pre-calculating model outputs in batches.
+This pre-calculates model outputs in batches, so it runs a lot faster.
See `diart.benchmark -h` for more options.

For convenience and to facilitate future comparisons, we also provide the [expected outputs](/expected_outputs) of the paper implementation in RTTM format for every entry of Table 1 and Figure 5. This includes the VBx offline topline as well as our proposed online approach with latencies 500ms, 1s, 2s, 3s, 4s, and 5s.
requirements.txt (3 changes: 2 additions & 1 deletion)
@@ -7,7 +7,8 @@ einops>=0.3.0
tqdm>=4.64.0
pandas>=1.4.2
torchaudio>=0.10,<1.0
-pyannote.core>=4.4
+pyannote.core>=4.5
+pyannote.database>=4.1.1
pyannote.metrics>=3.2
optuna>=2.10
websockets>=10.3
setup.cfg (5 changes: 3 additions & 2 deletions)
@@ -1,6 +1,6 @@
[metadata]
name=diart
-version=0.4.0
+version=0.5.0
author=Juan Manuel Coria
description=Speaker diarization in real time
long_description=file: README.md
@@ -29,10 +29,11 @@ install_requires=
    tqdm>=4.64.0
    pandas>=1.4.2
    torchaudio>=0.10,<1.0
-    pyannote.core>=4.4
+    pyannote.core>=4.5
+    pyannote.database>=4.1.1
    pyannote.metrics>=3.2
    optuna>=2.10
    websockets>=10.3

[options.packages.find]
where=src
src/diart/argdoc.py (2 changes: 2 additions & 0 deletions)
@@ -1,3 +1,5 @@
+SEGMENTATION = "Segmentation model name from pyannote"
+EMBEDDING = "Embedding model name from pyannote"
STEP = "Sliding window step (in seconds)"
LATENCY = "System latency (in seconds). STEP <= LATENCY <= CHUNK_DURATION"
TAU = "Probability threshold to consider a speaker as active. 0 <= TAU <= 1"
src/diart/benchmark.py (21 changes: 16 additions & 5 deletions)
@@ -1,16 +1,22 @@
import argparse
+from pathlib import Path

import torch

import diart.argdoc as argdoc
from diart.inference import Benchmark
+from diart.models import SegmentationModel, EmbeddingModel
from diart.pipelines import OnlineSpeakerDiarization, PipelineConfig


def run():
    parser = argparse.ArgumentParser()
-    parser.add_argument("root", type=str, help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)")
+    parser.add_argument("root", type=Path, help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)")
+    parser.add_argument("--segmentation", default="pyannote/segmentation", type=str,
+                        help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation")
+    parser.add_argument("--embedding", default="pyannote/embedding", type=str,
+                        help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding")
-    parser.add_argument("--reference", type=str,
+    parser.add_argument("--reference", type=Path,
                        help="Optional. Directory with RTTM files CONVERSATION.rttm. Names must match audio files")
    parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5")
    parser.add_argument("--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5")
@@ -23,20 +29,25 @@ def run():
parser.add_argument("--batch-size", default=32, type=int, help=f"{argdoc.BATCH_SIZE}. Defaults to 32")
parser.add_argument("--cpu", dest="cpu", action="store_true",
help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise")
parser.add_argument("--output", type=str, help=f"{argdoc.OUTPUT}. Defaults to `root`")
parser.add_argument("--output", type=Path, help=f"{argdoc.OUTPUT}. Defaults to no writing")
args = parser.parse_args()
args.device = torch.device("cpu") if args.cpu else None
args.segmentation = SegmentationModel.from_pyannote(args.segmentation)
args.embedding = EmbeddingModel.from_pyannote(args.embedding)

benchmark = Benchmark(
args.root,
args.reference,
args.output,
show_progress=True,
show_report=True,
batch_size=args.batch_size
batch_size=args.batch_size,
)

benchmark(OnlineSpeakerDiarization(PipelineConfig.from_namespace(args), profile=True))
pipeline = OnlineSpeakerDiarization(PipelineConfig.from_namespace(args), profile=True)
report = benchmark(pipeline)
if args.output is not None and report is not None:
report.to_csv(args.output / "benchmark_report.csv")


if __name__ == "__main__":
src/diart/blocks/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -13,4 +13,4 @@
    OverlapAwareSpeakerEmbedding,
)
from .segmentation import SpeakerSegmentation
-from .utils import Binarize
+from .utils import Binarize, Resample, AdjustVolume
src/diart/blocks/segmentation.py (2 changes: 1 addition & 1 deletion)
@@ -37,4 +37,4 @@ def __call__(self, waveform: TemporalFeatures) -> TemporalFeatures:
        with torch.no_grad():
            wave = rearrange(self.formatter.cast(waveform), "batch sample channel -> batch channel sample")
            output = self.model(wave.to(self.device)).cpu()
-            return self.formatter.restore_type(output)
\ No newline at end of file
+            return self.formatter.restore_type(output)