Version 0.3.0 #56

Merged · 42 commits · May 20, 2022
Commits:

- d499ef0 · Python 3.7 compatibility (#29) · Yagna24, Apr 26, 2022
- 935b630 · Start refactoring for batched diarization pipeline · juanmc2005, Apr 22, 2022
- 68126c8 · Batchify FrameWiseModel and ChunkWiseModel · juanmc2005, Apr 22, 2022
- 5ac42e4 · Add batched pipeline implementation · juanmc2005, Apr 25, 2022
- 6aeab04 · Move pre-calculated pipeline to OnlineSpeakerDiarization.from_file() · juanmc2005, Apr 25, 2022
- 46dc353 · Add argument to skip plotting for faster inference in demo script · juanmc2005, Apr 25, 2022
- 53f0f04 · Remove empty line · juanmc2005, Apr 26, 2022
- 7bb7cda · Add benchmark script. Add optional verbosity to from_file(). Add tqdm… · juanmc2005, Apr 26, 2022
- 5089d52 · Dumb down PipelineConfig. Make sample rate completely depend on the s… · juanmc2005, Apr 27, 2022
- e02ace3 · Fix segmentation resolution not being adapted to chunk duration · juanmc2005, Apr 27, 2022
- ecaea24 · Add DER evaluation to benchmark script. Add FileAudioSource parameter… · juanmc2005, Apr 27, 2022
- 63525df · Add optional processing time profiling in FileAudioSource · juanmc2005, Apr 27, 2022
- b15db1b · Add GPU support in demo and benchmarking · juanmc2005, Apr 27, 2022
- dad7153 · Make reference optional in benchmarking script · juanmc2005, Apr 27, 2022
- 8287de9 · Calculate number of chunks from duration instead of samples in ChunkL… · juanmc2005, May 2, 2022
- a19041c · Fix bug in batched pipeline: an edge case was causing the batch dimen… · juanmc2005, May 3, 2022
- d8362fb · Fix bug in from_file(): segmentation and embedding remove batch dimen… · juanmc2005, May 3, 2022
- 4861578 · Fix end time bug in batched pipeline · juanmc2005, May 4, 2022
- 1830d22 · Centralize stream end time calculation · juanmc2005, May 4, 2022
- 24dd009 · Add diart.benchmark in readme file · juanmc2005, May 4, 2022
- 8a80c2d · Add pyannote.metrics performance report in diart.benchmark · juanmc2005, May 4, 2022
- b3dfebe · Add progress bar to demo script · juanmc2005, May 4, 2022
- 7ecae25 · Fix method docstring · juanmc2005, May 4, 2022
- accf0e7 · Merge branch 'main' of github.com:juanmc2005/StreamingSpeakerDiarizat… · juanmc2005, May 8, 2022
- 5cdeea2 · Add tqdm requirement to setup.cfg · juanmc2005, May 8, 2022
- aa8de08 · Rename demo.py to stream.py. Unify script argument docs. Update READM… · juanmc2005, May 8, 2022
- 1fb18ca · Rename functional.py to blocks.py. Rename FrameWiseModel and ChunkWis… · juanmc2005, May 8, 2022
- 9488368 · Update README.md · juanmc2005, May 8, 2022
- bdcb242 · Add OverlapAwareSpeakerEmbedding block · juanmc2005, May 9, 2022
- 0510d5a · Update README.md · juanmc2005, May 9, 2022
- c943ef3 · Add Benchmark class · juanmc2005, May 12, 2022
- f0de692 · Add RealTimeInference class. Minor bug fix in buffer_output() · juanmc2005, May 12, 2022
- 091cbe2 · Add 'inference' module containing RealTimeInference and Benchmark · juanmc2005, May 12, 2022
- db3fd9d · Add docstrings to RealTimeInference and Benchmark · juanmc2005, May 13, 2022
- f450455 · Use last incomplete chunk with padding · juanmc2005, May 14, 2022
- dee15c4 · Improve benchmark progress logging · juanmc2005, May 18, 2022
- 13d52c2 · Add docstring correction · juanmc2005, May 20, 2022
- cae4580 · Add inference API to README.md · juanmc2005, May 20, 2022
- d3130ac · Add demo gif combining snippet and visualization · juanmc2005, May 20, 2022
- 9c7ef15 · Update README.md · juanmc2005, May 20, 2022
- 8b42d1e · Remove unused files · juanmc2005, May 20, 2022
- 17c125d · Update setup.cfg · juanmc2005, May 20, 2022
79 changes: 32 additions & 47 deletions README.md
@@ -18,28 +18,26 @@
<img width="51%" src="/visualization.gif" title="Real-time diarization example" />
</p>

-## Demo
-
-You can visualize the real-time speaker diarization of an audio stream with the built-in demo script.
+## Getting started

### Stream a recorded conversation

```shell
-python -m diart.demo /path/to/audio.wav
+python -m diart.stream /path/to/audio.wav
```

### Stream from your microphone

```shell
-python -m diart.demo microphone
+python -m diart.stream microphone
```

-See `python -m diart.demo -h` for more information.
+See `python -m diart.stream -h` for more options.

## Build your own pipeline

Diart provides building blocks that can be combined to do speaker diarization on an audio stream.
-The streaming implementation is powered by [RxPY](https://github.com/ReactiveX/RxPY), but the `functional` module is completely independent.
+Streaming is powered by [RxPY](https://github.com/ReactiveX/RxPY), but the `blocks` module is completely independent and can be used separately.

### Example

@@ -48,29 +46,25 @@
Obtain overlap-aware speaker embeddings from a microphone stream
```python
import rx
import rx.operators as ops
-import diart.operators as myops
+import diart.operators as dops
from diart.sources import MicrophoneAudioSource
-import diart.functional as fn
+from diart.blocks import FramewiseModel, OverlapAwareSpeakerEmbedding

sample_rate = 16000
mic = MicrophoneAudioSource(sample_rate)

# Initialize independent modules
-segmentation = fn.FrameWiseModel("pyannote/segmentation")
-embedding = fn.ChunkWiseModel("pyannote/embedding")
-osp = fn.OverlappedSpeechPenalty(gamma=3, beta=10)
-normalization = fn.EmbeddingNormalization(norm=1)
+segmentation = FramewiseModel("pyannote/segmentation")
+embedding = OverlapAwareSpeakerEmbedding("pyannote/embedding")

# Reformat microphone stream. Defaults to 5s duration and 500ms shift
-regular_stream = mic.stream.pipe(myops.regularize_stream(sample_rate))
+regular_stream = mic.stream.pipe(dops.regularize_stream(sample_rate))
# Branch the microphone stream to calculate segmentation
segmentation_stream = regular_stream.pipe(ops.map(segmentation))
# Join audio and segmentation stream to calculate speaker embeddings
-embedding_stream = rx.zip(regular_stream, segmentation_stream).pipe(
-    ops.starmap(lambda wave, seg: (wave, osp(seg))),
-    ops.starmap(embedding),
-    ops.map(normalization)
-)
+embedding_stream = rx.zip(
+    regular_stream, segmentation_stream
+).pipe(ops.starmap(embedding))

embedding_stream.subscribe(on_next=lambda emb: print(emb.shape))

@@ -91,11 +85,11 @@
torch.Size([4, 512])
1) Create environment:

```shell
-conda create -n diarization python==3.8
-conda activate diarization
+conda create -n diart python=3.8
+conda activate diart
```

-2) Install the latest PyTorch version following the [official instructions](https://pytorch.org/get-started/locally/#start-locally)
+2) [Install PyTorch](https://pytorch.org/get-started/locally/#start-locally)

3) Install pyannote.audio 2.0 (currently in development)
```shell
@@ -137,38 +131,29 @@
If you found diart useful, please make sure to cite our paper:

![Results table](/table1.png)

-To reproduce the results of the paper, use the following hyper-parameters:
+Diart aims to be lightweight and capable of real-time streaming in practical scenarios.
+Its performance is very close to what is reported in the paper (and sometimes even a bit better).
+
+To obtain the best results, make sure to use the following hyper-parameters:

-Dataset | latency | tau | rho | delta
-------------|---------|--------|--------|------
-DIHARD III | any | 0.555 | 0.422 | 1.517
-AMI | any | 0.507 | 0.006 | 1.057
-VoxConverse | any | 0.576 | 0.915 | 0.648
-DIHARD II | 1s | 0.619 | 0.326 | 0.997
-DIHARD II | 5s | 0.555 | 0.422 | 1.517
+| Dataset     | latency | tau   | rho   | delta |
+|-------------|---------|-------|-------|-------|
+| DIHARD III  | any     | 0.555 | 0.422 | 1.517 |
+| AMI         | any     | 0.507 | 0.006 | 1.057 |
+| VoxConverse | any     | 0.576 | 0.915 | 0.648 |
+| DIHARD II   | 1s      | 0.619 | 0.326 | 0.997 |
+| DIHARD II   | 5s      | 0.555 | 0.422 | 1.517 |

-For instance, for a DIHARD III configuration:
+`diart.benchmark` can quickly run and evaluate the pipeline, and even measure its real-time latency. For instance, for a DIHARD III configuration:

```shell
-python -m diart.demo /path/to/file.wav --tau=0.555 --rho=0.422 --delta=1.517 --output /output/dir
+python -m diart.benchmark /wav/dir --reference /rttm/dir --tau=0.555 --rho=0.422 --delta=1.517 --output /out/dir
```

-And then to obtain the diarization error rate:
-
-```python
-from pyannote.metrics.diarization import DiarizationErrorRate
-from pyannote.database.util import load_rttm
-
-metric = DiarizationErrorRate()
-hypothesis = load_rttm("/output/dir/output.rttm")
-hypothesis = list(hypothesis.values())[0]  # Extract hypothesis from dictionary
-reference = load_rttm("/path/to/reference.rttm")
-reference = list(reference.values())[0]  # Extract reference from dictionary
-
-der = metric(reference, hypothesis)
-```
+`diart.benchmark` runs a faster inference and evaluation by pre-calculating model outputs in batches.
+See `python -m diart.benchmark -h` for more options.

-For convenience and to facilitate future comparisons, we also provide the [expected outputs](/expected_outputs) in RTTM format for every entry of Table 1 and Figure 5 in the paper. This includes the VBx offline topline as well as our proposed online approach with latencies 500ms, 1s, 2s, 3s, 4s, and 5s.
+For convenience and to facilitate future comparisons, we also provide the [expected outputs](/expected_outputs) of the paper implementation in RTTM format for every entry of Table 1 and Figure 5. This includes the VBx offline topline as well as our proposed online approach with latencies 500ms, 1s, 2s, 3s, 4s, and 5s.

![Figure 5](/figure5.png)

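Beyond the CLI, this release also exposes the evaluation loop programmatically through the new `diart.inference` module (commits c943ef3, f0de692, 091cbe2). A minimal sketch, mirroring the `diart.benchmark` script included later in this diff; the directory paths are placeholders:

```python
import torch

from diart.inference import Benchmark
from diart.pipelines import OnlineSpeakerDiarization, PipelineConfig

# Evaluate every audio file in /wav/dir against matching RTTM references
benchmark = Benchmark("/wav/dir", "/rttm/dir", "/out/dir")

# DIHARD III hyper-parameters from the table above
pipeline = OnlineSpeakerDiarization(PipelineConfig(
    step=0.5,
    latency=0.5,
    tau_active=0.555,
    rho_update=0.422,
    delta_new=1.517,
    gamma=3,
    beta=10,
    max_speakers=20,
    device=torch.device("cuda"),
))

# batch_size >= 2 pre-calculates model outputs in batches for faster evaluation
benchmark(pipeline, 32)
```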
2 changes: 2 additions & 0 deletions requirements.txt
@@ -4,4 +4,6 @@
rx>=3.2.0
scipy>=1.6.0
sounddevice>=0.4.2
einops>=0.3.0
+tqdm>=4.64.0
+pandas>=1.4.2
git+https://github.com/pyannote/pyannote-audio.git@develop#egg=pyannote-audio
1 change: 1 addition & 0 deletions setup.cfg
@@ -26,6 +26,7 @@
install_requires =
scipy>=1.6.0
sounddevice>=0.4.2
einops>=0.3.0
+tqdm>=4.64.0
pyannote-audio @ git+https://github.com/pyannote/pyannote-audio.git@develop#egg=pyannote-audio


10 changes: 10 additions & 0 deletions src/diart/argdoc.py
@@ -0,0 +1,10 @@
STEP = "Sliding window step (in seconds)"
LATENCY = "System latency (in seconds). STEP <= LATENCY <= CHUNK_DURATION"
TAU = "Probability threshold to consider a speaker as active. 0 <= TAU <= 1"
RHO = "Speech ratio threshold to decide if centroids are updated with a given speaker. 0 <= RHO <= 1"
DELTA = "Embedding-to-centroid distance threshold to flag a speaker as known or new. 0 <= DELTA <= 2"
GAMMA = "Parameter gamma for overlapped speech penalty"
BETA = "Parameter beta for overlapped speech penalty"
MAX_SPEAKERS = "Maximum number of speakers"
GPU = "Run on GPU"
OUTPUT = "Directory to store the system's output in RTTM format"
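These shared descriptions keep `--help` text consistent across the CLI scripts (commit aa8de08, "Unify script argument docs"). As a minimal sketch of the intended pattern, mirroring the benchmark script below:

```python
import argparse

import diart.argdoc as argdoc

# Each script reuses the shared description and appends its own default
parser = argparse.ArgumentParser()
parser.add_argument("--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5")
parser.add_argument("--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5")
args = parser.parse_args()
```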
43 changes: 43 additions & 0 deletions src/diart/benchmark.py
@@ -0,0 +1,43 @@
import argparse

import torch

import diart.argdoc as argdoc
from diart.inference import Benchmark
from diart.pipelines import OnlineSpeakerDiarization, PipelineConfig

if __name__ == "__main__":
    # Define script arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("root", type=str, help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)")
    parser.add_argument("--reference", type=str, help="Optional. Directory with RTTM files CONVERSATION.rttm. Names must match audio files")
    parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5")
    parser.add_argument("--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5")
    parser.add_argument("--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5")
    parser.add_argument("--rho", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3")
    parser.add_argument("--delta", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1")
    parser.add_argument("--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. Defaults to 3")
    parser.add_argument("--beta", default=10, type=float, help=f"{argdoc.BETA}. Defaults to 10")
    parser.add_argument("--max-speakers", default=20, type=int, help=f"{argdoc.MAX_SPEAKERS}. Defaults to 20")
    parser.add_argument("--batch-size", default=32, type=int, help="For segmentation and embedding pre-calculation. If BATCH_SIZE < 2, run fully online and estimate real-time latency. Defaults to 32")
    parser.add_argument("--gpu", dest="gpu", action="store_true", help=argdoc.GPU)
    parser.add_argument("--output", type=str, help=f"{argdoc.OUTPUT}. Defaults to `root`")
    args = parser.parse_args()

    # Set benchmark configuration
    benchmark = Benchmark(args.root, args.reference, args.output)

    # Define online speaker diarization pipeline
    pipeline = OnlineSpeakerDiarization(PipelineConfig(
        step=args.step,
        latency=args.latency,
        tau_active=args.tau,
        rho_update=args.rho,
        delta_new=args.delta,
        gamma=args.gamma,
        beta=args.beta,
        max_speakers=args.max_speakers,
        device=torch.device("cuda") if args.gpu else None,
    ))

    benchmark(pipeline, args.batch_size)
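For reference, the script above can be invoked like the README example; two illustrative runs, with directory paths as placeholders:

```shell
# Batched evaluation with the DIHARD III hyper-parameters
python -m diart.benchmark /wav/dir --reference /rttm/dir --tau=0.555 --rho=0.422 --delta=1.517 --output /out/dir

# A batch size below 2 runs fully online and also estimates real-time latency
python -m diart.benchmark /wav/dir --batch-size 1
```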