Commit: Merge pull request #66 from juanmc2005/develop

Version 0.4
juanmc2005 authored Jul 13, 2022
2 parents 0ebf729 + b40f091 commit 17b29b2
Showing 25 changed files with 1,809 additions and 1,079 deletions.
213 changes: 175 additions & 38 deletions README.md
@@ -12,13 +12,49 @@
<img alt="License" src="https://img.shields.io/github/license/juanmc2005/StreamingSpeakerDiarization?color=g">
</p>

<div align="center">
<h4>
<a href="#installation">
Installation
</a>
<span> | </span>
<a href="#stream-audio">
Stream audio
</a>
<span> | </span>
<a href="#add-your-model">
Add your model
</a>
<span> | </span>
<a href="#tune-hyper-parameters">
Tune hyper-parameters
</a>
<span> | </span>
<a href="#build-pipelines">
Build pipelines
</a>
<br/>
<a href="#powered-by-research">
Research
</a>
<span> | </span>
<a href="#citation">
Citation
</a>
<span> | </span>
<a href="#reproducibility">
Reproducibility
</a>
</h4>
</div>

<br/>

<p align="center">
<img width="100%" src="/demo.gif" title="Real-time diarization example" />
</p>

## Install
## Installation

1) Create environment:

@@ -27,85 +63,186 @@ conda create -n diart python=3.8
conda activate diart
```

2) [Install PyTorch](https://pytorch.org/get-started/locally/#start-locally)
2) Install `PortAudio` and `soundfile`:

```shell
conda install portaudio
conda install pysoundfile -c conda-forge
```

3) [Install PyTorch](https://pytorch.org/get-started/locally/#start-locally)

4) Install pyannote.audio 2.0 (currently in development)

3) Install pyannote.audio 2.0 (currently in development)
```shell
pip install git+https://github.com/pyannote/pyannote-audio.git@develop#egg=pyannote-audio
```

4) Install diart:
**Note:** starting from version 0.4, installing pyannote.audio is mandatory to run the default system or to use pyannote-based models. Otherwise, you can skip this step.

5) Install diart:
```shell
pip install diart
```

## Stream your own audio
## Stream audio

### A recorded conversation
### From the command line

A recorded conversation:

```shell
python -m diart.stream /path/to/audio.wav
diart.stream /path/to/audio.wav
```

### From your microphone
A live conversation:

```shell
python -m diart.stream microphone
diart.stream microphone
```

See `python -m diart.stream -h` for more options.
See `diart.stream -h` for more options.
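
For instance, to stream with custom hyper-parameter values while writing predictions to disk. This sketch assumes `diart.stream` shares the `--tau`, `--rho`, `--delta` and `--output` options shown for `diart.benchmark` later in this README; run `diart.stream -h` to confirm:

```shell
diart.stream /path/to/audio.wav --tau=0.555 --rho=0.422 --delta=1.517 --output /out/dir
```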

## Inference API
### From python

Run a customized real-time speaker diarization pipeline over an audio stream with `diart.inference.RealTimeInference`:
Run a real-time speaker diarization pipeline over an audio stream with `RealTimeInference`:

```python
from diart.sources import MicrophoneAudioSource
from diart.inference import RealTimeInference
from diart.pipelines import OnlineSpeakerDiarization, PipelineConfig

pipeline = OnlineSpeakerDiarization(PipelineConfig())
audio_source = MicrophoneAudioSource(pipeline.sample_rate)
config = PipelineConfig() # Default parameters
pipeline = OnlineSpeakerDiarization(config)
audio_source = MicrophoneAudioSource(config.sample_rate)
inference = RealTimeInference("/output/path", do_plot=True)

inference(pipeline, audio_source)
```

For faster inference and evaluation on a dataset we recommend to use `Benchmark` (see our notes on [reproducibility](#reproducibility))
For faster inference and evaluation on a dataset, we recommend using `Benchmark` instead (see our notes on [reproducibility](#reproducibility)).

## Add your model

Third-party segmentation and embedding models can be integrated seamlessly by subclassing `SegmentationModel` and `EmbeddingModel`:

```python
import torch
from typing import Optional
from diart.models import EmbeddingModel
from diart.pipelines import PipelineConfig, OnlineSpeakerDiarization
from diart.sources import MicrophoneAudioSource
from diart.inference import RealTimeInference

class MyEmbeddingModel(EmbeddingModel):
    def __init__(self):
        super().__init__()
        self.my_pretrained_model = load("my_model.ckpt")

    def __call__(
        self,
        waveform: torch.Tensor,
        weights: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        return self.my_pretrained_model(waveform, weights)

config = PipelineConfig(embedding=MyEmbeddingModel())
pipeline = OnlineSpeakerDiarization(config)
mic = MicrophoneAudioSource(config.sample_rate)
inference = RealTimeInference("/out/dir")
inference(pipeline, mic)
```
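
Segmentation works the same way. A minimal sketch of a custom `SegmentationModel`; the `get_sample_rate` and `get_duration` methods and the `segmentation` argument of `PipelineConfig` are assumptions by symmetry with the embedding example above (`get_sample_rate` also appears in the pipeline example further below), and `load` is a placeholder for your own checkpoint loading:

```python
import torch
from diart.models import SegmentationModel
from diart.pipelines import PipelineConfig

class MySegmentationModel(SegmentationModel):
    def __init__(self):
        super().__init__()
        self.my_pretrained_model = load("my_model.ckpt")

    def get_sample_rate(self) -> int:
        # Sample rate expected by the model
        return 16000

    def get_duration(self) -> float:
        # Duration in seconds of the expected input chunks
        return 5.0

    def __call__(self, waveform: torch.Tensor) -> torch.Tensor:
        # Should return per-frame speaker activity scores
        return self.my_pretrained_model(waveform)

config = PipelineConfig(segmentation=MySegmentationModel())
```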

## Tune hyper-parameters

Diart implements a hyper-parameter optimizer based on [optuna](https://optuna.readthedocs.io/en/stable/index.html) that allows you to tune any pipeline to any dataset.

### From the command line

```shell
diart.tune /wav/dir --reference /rttm/dir --output /out/dir
```

## Build your own pipeline
See `diart.tune -h` for more options.

Diart also provides building blocks that can be combined to create your own pipeline.
### From python

```python
from diart.optim import Optimizer, TauActive, RhoUpdate, DeltaNew
from diart.pipelines import PipelineConfig
from diart.inference import Benchmark

# Benchmark runs and evaluates the pipeline on a dataset
benchmark = Benchmark("/wav/dir", "/rttm/dir", "/out/dir/tmp", show_report=False)
# Base configuration for the pipeline we're going to tune
base_config = PipelineConfig()
# Hyper-parameters to optimize
hparams = [TauActive, RhoUpdate, DeltaNew]
# Optimizer implements the optimization loop
optimizer = Optimizer(benchmark, base_config, hparams, "/out/dir")
# Run optimization
optimizer.optimize(num_iter=100, show_progress=True)
```

This will use `/out/dir/tmp` as a working directory and write results to an SQLite database in `/out/dir`.
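
To inspect results afterwards, the study can be loaded back with optuna. A sketch; the study name and database file name here are hypothetical (check `/out/dir` for the actual ones):

```python
import optuna

# Hypothetical study and database file names
study = optuna.load_study("my_study", "sqlite:////out/dir/my_study.db")
print(study.best_params)  # best hyper-parameters found so far
print(study.best_value)   # best objective value found so far
```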

### Distributed optimization

For bigger datasets, it is sometimes more convenient to run multiple optimization processes in parallel.
To do this, create a study on a [recommended DBMS](https://optuna.readthedocs.io/en/stable/tutorial/10_key_features/004_distributed.html#sphx-glr-tutorial-10-key-features-004-distributed-py) (e.g. MySQL or PostgreSQL), making sure that the study and database names match:

```shell
mysql -u root -e "CREATE DATABASE IF NOT EXISTS example"
optuna create-study --study-name "example" --storage "mysql://root@localhost/example"
```

Then you can run multiple identical optimizers pointing to the database:

```shell
diart.tune /wav/dir --reference /rttm/dir --output /out/dir --storage mysql://root@localhost/example
```

If you are using the Python API, make sure that each worker uses a different working directory to avoid concurrency issues:

```python
from diart.optim import Optimizer
from diart.inference import Benchmark
from optuna.samplers import TPESampler
import optuna

ID = 0 # Worker identifier
base_config, hparams = ...
benchmark = Benchmark("/wav/dir", "/rttm/dir", f"/out/dir/worker-{ID}", show_report=False)
study = optuna.load_study("example", "mysql://root@localhost/example", TPESampler())
optimizer = Optimizer(benchmark, base_config, hparams, study)
optimizer.optimize(num_iter=100, show_progress=True)
```
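
Each worker then contributes trials to the same study. As an illustration, assuming the script above is saved as `tune_worker.py` and adapted to read `ID` from the environment, four workers could be launched like this:

```shell
for ID in 0 1 2 3; do
  ID=$ID python tune_worker.py &
done
wait
```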

## Build pipelines

For more advanced usage, diart also provides building blocks that can be combined to create your own pipeline.
Streaming is powered by [RxPY](https://github.com/ReactiveX/RxPY), but the `blocks` module is completely independent and can be used separately.

### Example

Obtain overlap-aware speaker embeddings from a microphone stream:

```python
import rx
import rx.operators as ops
import diart.operators as dops
from diart.sources import MicrophoneAudioSource
from diart.blocks import FramewiseModel, OverlapAwareSpeakerEmbedding
from diart.blocks import SpeakerSegmentation, OverlapAwareSpeakerEmbedding

sample_rate = 16000
segmentation = SpeakerSegmentation.from_pyannote("pyannote/segmentation")
embedding = OverlapAwareSpeakerEmbedding.from_pyannote("pyannote/embedding")
sample_rate = segmentation.model.get_sample_rate()
mic = MicrophoneAudioSource(sample_rate)

# Initialize independent modules
segmentation = FramewiseModel("pyannote/segmentation")
embedding = OverlapAwareSpeakerEmbedding("pyannote/embedding")

# Reformat microphone stream. Defaults to 5s duration and 500ms shift
regular_stream = mic.stream.pipe(dops.regularize_stream(sample_rate))
# Branch the microphone stream to calculate segmentation
segmentation_stream = regular_stream.pipe(ops.map(segmentation))
# Join audio and segmentation stream to calculate speaker embeddings
embedding_stream = rx.zip(
    regular_stream, segmentation_stream
).pipe(ops.starmap(embedding))

embedding_stream.subscribe(on_next=lambda emb: print(emb.shape))
stream = mic.stream.pipe(
    # Reformat stream to 5s duration and 500ms shift
    dops.regularize_audio_stream(sample_rate),
    ops.map(lambda wav: (wav, segmentation(wav))),
    ops.starmap(embedding)
).subscribe(on_next=lambda emb: print(emb.shape))

mic.read()
```
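
Since the `blocks` module does not depend on RxPY, each block can also be called on its own. A minimal sketch; the input shape `(samples, channels)` is an assumption about the expected chunk format:

```python
import torch
from diart.blocks import SpeakerSegmentation

segmentation = SpeakerSegmentation.from_pyannote("pyannote/segmentation")
sample_rate = segmentation.model.get_sample_rate()

# A single 5-second chunk of silence, shaped (samples, channels)
chunk = torch.zeros(5 * sample_rate, 1)
probs = segmentation(chunk)  # per-frame speaker activity scores
print(probs.shape)
```
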
@@ -165,7 +302,7 @@ To obtain the best results, make sure to use the following hyper-parameters:
`diart.benchmark` and `diart.inference.Benchmark` can quickly run and evaluate the pipeline, and even measure its real-time latency. For instance, for a DIHARD III configuration:

```shell
python -m diart.benchmark /wav/dir --reference /rttm/dir --tau=0.555 --rho=0.422 --delta=1.517 --output /out/dir
diart.benchmark /wav/dir --reference /rttm/dir --tau=0.555 --rho=0.422 --delta=1.517 --output /out/dir
```

or using the inference API:
@@ -184,11 +321,11 @@ config = PipelineConfig(
pipeline = OnlineSpeakerDiarization(config)
benchmark = Benchmark("/wav/dir", "/rttm/dir", "/out/dir")

benchmark(pipeline, batch_size=32)
benchmark(pipeline)
```

This speeds up inference by pre-calculating model outputs in batches.
See `python -m diart.benchmark -h` for more options.
See `diart.benchmark -h` for more options.
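
The batch size is also documented in `argdoc.py` below ("For segmentation and embedding pre-calculation"); assuming the CLI exposes it as `--batch-size`, a batched run could look like this:

```shell
diart.benchmark /wav/dir --reference /rttm/dir --batch-size 32 --output /out/dir
```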

For convenience and to facilitate future comparisons, we also provide the [expected outputs](/expected_outputs) of the paper implementation in RTTM format for every entry of Table 1 and Figure 5. This includes the VBx offline topline as well as our proposed online approach with latencies 500ms, 1s, 2s, 3s, 4s, and 5s.
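
To score one of your runs against these expected outputs, the RTTM files can be compared with `pyannote.metrics` (listed as a dependency below). A small sketch; the file paths and URI key are placeholders:

```python
from pyannote.database.util import load_rttm
from pyannote.metrics.diarization import DiarizationErrorRate

# load_rttm returns a dictionary of annotations indexed by file URI
reference = load_rttm("/rttm/dir/file.rttm")["file"]
hypothesis = load_rttm("/out/dir/file.rttm")["file"]

metric = DiarizationErrorRate()
print(f"DER = {metric(reference, hypothesis):.3f}")
```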

6 changes: 5 additions & 1 deletion requirements.txt
@@ -6,4 +6,8 @@ sounddevice>=0.4.2
einops>=0.3.0
tqdm>=4.64.0
pandas>=1.4.2
git+https://github.com/pyannote/pyannote-audio.git@develop#egg=pyannote-audio
torchaudio>=0.10,<1.0
pyannote.core>=4.4
pyannote.database>=4.1.1
pyannote.metrics>=3.2
optuna>=2.10
19 changes: 14 additions & 5 deletions setup.cfg
@@ -1,6 +1,6 @@
[metadata]
name=diart
version=0.3.0
version=0.4.0
author=Juan Manuel Coria
description=Speaker diarization in real time
long_description=file: README.md
@@ -9,7 +9,7 @@ keywords=speaker diarization, streaming, online, real time, rxpy
url=https://github.com/juanmc2005/StreamingSpeakerDiarization
license=MIT
classifiers=
Development Status :: 3 - Alpha
Development Status :: 4 - Beta
License :: OSI Approved :: MIT License
Topic :: Multimedia :: Sound/Audio :: Analysis
Topic :: Multimedia :: Sound/Audio :: Speech
@@ -19,7 +19,7 @@ classifiers=
package_dir=
=src
packages=find:
install_requires =
install_requires=
numpy>=1.20.2
matplotlib>=3.3.3
rx>=3.2.0
@@ -28,8 +28,17 @@ install_requires =
einops>=0.3.0
tqdm>=4.64.0
pandas>=1.4.2
pyannote-audio @ git+https://github.com/pyannote/pyannote-audio.git@develop#egg=pyannote-audio

torchaudio>=0.10,<1.0
pyannote.core>=4.4
pyannote.database>=4.1.1
pyannote.metrics>=3.2
optuna>=2.10

[options.packages.find]
where=src

[options.entry_points]
console_scripts=
diart.stream=diart.stream:run
diart.benchmark=diart.benchmark:run
diart.tune=diart.tune:run
3 changes: 2 additions & 1 deletion src/diart/argdoc.py
@@ -6,5 +6,6 @@
GAMMA = "Parameter gamma for overlapped speech penalty"
BETA = "Parameter beta for overlapped speech penalty"
MAX_SPEAKERS = "Maximum number of speakers"
GPU = "Run on GPU"
CPU = "Force models to run on CPU"
BATCH_SIZE = "For segmentation and embedding pre-calculation. If BATCH_SIZE < 2, run fully online and estimate real-time latency"
OUTPUT = "Directory to store the system's output in RTTM format"