Version 0.8 #192

Merged · 14 commits · Oct 28, 2023
69 changes: 31 additions & 38 deletions README.md
@@ -24,7 +24,7 @@
</a>
<span> | </span>
<a href="#-custom-models">
🤖 Custom models
🤖 Add your model
</a>
<span> | </span>
<a href="#-tune-hyper-parameters">
@@ -64,17 +64,11 @@
1) Create environment:

```shell
conda create -n diart python=3.8
conda env create -f diart/environment.yml
conda activate diart
```

2) Install audio libraries:

```shell
conda install portaudio pysoundfile ffmpeg -c conda-forge
```

3) Install diart:
2) Install the package:
```shell
pip install diart
```
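A quick import check confirms the install picked up the 0.8 API (a minimal sketch; the class names are taken from the renames later in this diff):

```python
# These names replace the old OnlineSpeakerDiarization / PipelineConfig exports
from diart import SpeakerDiarization, SpeakerDiarizationConfig

print(SpeakerDiarization.__module__)  # diart.blocks.diarization, per this PR
```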
@@ -110,32 +104,32 @@ See `diart.stream -h` for more options.

### From python

Use `RealTimeInference` to easily run a pipeline on an audio source and write the results to disk:
Use `StreamingInference` to run a pipeline on an audio source and write the results to disk:

```python
from diart import OnlineSpeakerDiarization
from diart import SpeakerDiarization
from diart.sources import MicrophoneAudioSource
from diart.inference import RealTimeInference
from diart.inference import StreamingInference
from diart.sinks import RTTMWriter

pipeline = OnlineSpeakerDiarization()
mic = MicrophoneAudioSource(pipeline.config.sample_rate)
inference = RealTimeInference(pipeline, mic, do_plot=True)
pipeline = SpeakerDiarization()
mic = MicrophoneAudioSource()
inference = StreamingInference(pipeline, mic, do_plot=True)
inference.attach_observers(RTTMWriter(mic.uri, "/output/file.rttm"))
prediction = inference()
```
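The same pipeline should run over a recording by swapping the source (a sketch only; `FileAudioSource` exists in `diart.sources`, but the exact arguments used here are an assumption):

```python
from diart import SpeakerDiarization
from diart.sources import FileAudioSource
from diart.inference import StreamingInference
from diart.sinks import RTTMWriter

pipeline = SpeakerDiarization()
# Assumption: a file path plus the pipeline's sample rate is enough
source = FileAudioSource("/path/to/audio.wav", pipeline.config.sample_rate)
inference = StreamingInference(pipeline, source, do_plot=False)
inference.attach_observers(RTTMWriter(source.uri, "/output/file.rttm"))
prediction = inference()
```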

For inference and evaluation on a dataset we recommend to use `Benchmark` (see notes on [reproducibility](#reproducibility)).
For inference and evaluation on a dataset we recommend using `Benchmark` (see notes on [reproducibility](#-reproducibility)).

## 🤖 Custom models
## 🤖 Add your model

Third-party models can be integrated seamlessly by subclassing `SegmentationModel` and `EmbeddingModel` (which are PyTorch `Module` subclasses):
Third-party models can be integrated by subclassing `SegmentationModel` and `EmbeddingModel` (both PyTorch `nn.Module` subclasses):

```python
from diart import OnlineSpeakerDiarization, PipelineConfig
from diart import SpeakerDiarization, SpeakerDiarizationConfig
from diart.models import EmbeddingModel, SegmentationModel
from diart.sources import MicrophoneAudioSource
from diart.inference import RealTimeInference
from diart.inference import StreamingInference


def model_loader():
@@ -168,19 +162,19 @@ class MyEmbeddingModel(EmbeddingModel):
        return self.model(waveform, weights)


config = PipelineConfig(
config = SpeakerDiarizationConfig(
    segmentation=MySegmentationModel(),
    embedding=MyEmbeddingModel()
)
pipeline = OnlineSpeakerDiarization(config)
mic = MicrophoneAudioSource(config.sample_rate)
inference = RealTimeInference(pipeline, mic)
pipeline = SpeakerDiarization(config)
mic = MicrophoneAudioSource()
inference = StreamingInference(pipeline, mic)
prediction = inference()
```
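The segmentation side, elided by the collapsed hunk above, can follow the same pattern; a minimal sketch, assuming the loader-based constructor and the `sample_rate`/`duration` properties suggested by the visible fragment (`load_my_model` is a hypothetical helper):

```python
import torch
from diart.models import SegmentationModel


def model_loader():
    return load_my_model()  # hypothetical: returns a pretrained nn.Module


class MySegmentationModel(SegmentationModel):
    def __init__(self):
        super().__init__(model_loader)

    @property
    def sample_rate(self) -> int:
        return 16000

    @property
    def duration(self) -> float:
        return 5.0  # seconds of audio per chunk

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        # (batch, channels, samples) -> (batch, frames, speakers)
        return self.model(waveform)
```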

## 📈 Tune hyper-parameters

Diart implements a hyper-parameter optimizer based on [optuna](https://optuna.readthedocs.io/en/stable/index.html) that allows you to tune any pipeline to any dataset.
Diart implements an optimizer based on [optuna](https://optuna.readthedocs.io/en/stable/index.html) that allows you to tune pipeline hyper-parameters to your needs.
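From python, the optimizer can presumably be driven directly; a sketch, assuming an `Optimizer` entry point in `diart.optim` with path-based arguments:

```python
from diart.optim import Optimizer

# Assumption: audio dir, reference RTTM dir, and an output/study location
optimizer = Optimizer("/wav/dir", "/rttm/dir", "/output/dir")
optimizer(num_iter=100)  # run 100 optuna trials
```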

### From the command line

@@ -247,12 +241,11 @@ from diart.blocks import SpeakerSegmentation, OverlapAwareSpeakerEmbedding

segmentation = SpeakerSegmentation.from_pyannote("pyannote/segmentation")
embedding = OverlapAwareSpeakerEmbedding.from_pyannote("pyannote/embedding")
sample_rate = segmentation.model.sample_rate
mic = MicrophoneAudioSource(sample_rate)
mic = MicrophoneAudioSource()

stream = mic.stream.pipe(
    # Reformat stream to 5s duration and 500ms shift
    dops.rearrange_audio_stream(sample_rate=sample_rate),
    dops.rearrange_audio_stream(sample_rate=segmentation.model.sample_rate),
    ops.map(lambda wav: (wav, segmentation(wav))),
    ops.starmap(embedding)
).subscribe(on_next=lambda emb: print(emb.shape))
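A self-contained variant of this fragment, filling in the imports hidden by the collapsed lines (the `dops`/`ops` aliases follow the visible code; the closing `mic.read()` call is an assumption):

```python
import rx.operators as ops
import diart.operators as dops
from diart.sources import MicrophoneAudioSource
from diart.blocks import SpeakerSegmentation, OverlapAwareSpeakerEmbedding

segmentation = SpeakerSegmentation.from_pyannote("pyannote/segmentation")
embedding = OverlapAwareSpeakerEmbedding.from_pyannote("pyannote/embedding")
mic = MicrophoneAudioSource()

stream = mic.stream.pipe(
    # Reformat stream to 5s duration and 500ms shift
    dops.rearrange_audio_stream(sample_rate=segmentation.model.sample_rate),
    ops.map(lambda wav: (wav, segmentation(wav))),
    ops.starmap(embedding)
).subscribe(on_next=lambda emb: print(emb.shape))

mic.read()  # assumption: starts consuming the microphone stream
```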
@@ -281,7 +274,7 @@ diart.serve --host 0.0.0.0 --port 7007
diart.client microphone --host <server-address> --port 7007
```

**Note:** please make sure that the client uses the same `step` and `sample_rate` than the server with `--step` and `-sr`.
**Note:** make sure that the client uses the same `step` and `sample_rate` as the server, via `--step` and `-sr`.

See `-h` for more options.

@@ -290,13 +283,13 @@ See `-h` for more options.
For customized solutions, a server can also be created in python using the `WebSocketAudioSource`:

```python
from diart import OnlineSpeakerDiarization
from diart import SpeakerDiarization
from diart.sources import WebSocketAudioSource
from diart.inference import RealTimeInference
from diart.inference import StreamingInference

pipeline = OnlineSpeakerDiarization()
pipeline = SpeakerDiarization()
source = WebSocketAudioSource(pipeline.config.sample_rate, "localhost", 7007)
inference = RealTimeInference(pipeline, source)
inference = StreamingInference(pipeline, source)
inference.attach_hooks(lambda ann_wav: source.send(ann_wav[0].to_rttm()))
prediction = inference()
```
@@ -347,21 +340,21 @@ To obtain the best results, make sure to use the following hyper-parameters:
`diart.benchmark` and `diart.inference.Benchmark` can run, evaluate and measure the real-time latency of the pipeline. For instance, for a DIHARD III configuration:

```shell
diart.benchmark /wav/dir --reference /rttm/dir --tau=0.555 --rho=0.422 --delta=1.517 --segmentation pyannote/segmentation@Interspeech2021
diart.benchmark /wav/dir --reference /rttm/dir --tau-active=0.555 --rho-update=0.422 --delta-new=1.517 --segmentation pyannote/segmentation@Interspeech2021
```

or using the inference API:

```python
from diart.inference import Benchmark, Parallelize
from diart import OnlineSpeakerDiarization, PipelineConfig
from diart import SpeakerDiarization, SpeakerDiarizationConfig
from diart.models import SegmentationModel

benchmark = Benchmark("/wav/dir", "/rttm/dir")

name = "pyannote/segmentation@Interspeech2021"
segmentation = SegmentationModel.from_pyannote(name)
config = PipelineConfig(
config = SpeakerDiarizationConfig(
    # Set the model used in the paper
    segmentation=segmentation,
    step=0.5,
@@ -370,12 +363,12 @@
    rho_update=0.422,
    delta_new=1.517
)
benchmark(OnlineSpeakerDiarization, config)
benchmark(SpeakerDiarization, config)

# Run the same benchmark in parallel
p_benchmark = Parallelize(benchmark, num_workers=4)
if __name__ == "__main__":  # Needed for multiprocessing
    p_benchmark(OnlineSpeakerDiarization, config)
    p_benchmark(SpeakerDiarization, config)
```

This pre-calculates model outputs in batches, so it runs a lot faster.
12 changes: 12 additions & 0 deletions environment.yml
@@ -0,0 +1,12 @@
name: diart
channels:
  - conda-forge
  - defaults
dependencies:
  - python=3.8
  - portaudio=19.6.*
  - pysoundfile=0.12.*
  - ffmpeg[version='<4.4']
  - pip
  - pip:
    - .
2 changes: 1 addition & 1 deletion requirements.txt
@@ -8,7 +8,7 @@ tqdm>=4.64.0
pandas>=1.4.2
torch>=1.12.1
torchvision>=0.14.0
torchaudio>=0.12.1,<1.0
torchaudio>=2.0.2
pyannote.audio>=2.1.1
pyannote.core>=4.5
pyannote.database>=4.1.1
8 changes: 4 additions & 4 deletions setup.cfg
@@ -1,12 +1,12 @@
[metadata]
name=diart
version=0.7.0
version=0.8.0
author=Juan Manuel Coria
description=Speaker diarization in real time
description=Streaming speaker diarization in real-time
long_description=file: README.md
long_description_content_type=text/markdown
keywords=speaker diarization, streaming, online, real time, rxpy
url=https://github.com/juanmc2005/StreamingSpeakerDiarization
url=https://github.com/juanmc2005/diart
license=MIT
classifiers=
    Development Status :: 4 - Beta
@@ -30,7 +30,7 @@ install_requires=
    pandas>=1.4.2
    torch>=1.12.1
    torchvision>=0.14.0
    torchaudio>=0.12.1,<1.0
    torchaudio>=2.0.2
    pyannote.audio>=2.1.1
    pyannote.core>=4.5
    pyannote.database>=4.1.1
8 changes: 5 additions & 3 deletions src/diart/__init__.py
@@ -1,6 +1,8 @@
from .blocks import (
    OnlineSpeakerDiarization,
    BasePipeline,
    SpeakerDiarization,
    Pipeline,
    SpeakerDiarizationConfig,
    PipelineConfig,
    BasePipelineConfig,
    VoiceActivityDetection,
    VoiceActivityDetectionConfig,
)
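Downstream code migrates one-for-one under these renames (a sketch based only on the names in this diff; note that `PipelineConfig` is still exported, now as the base config):

```python
# Pre-0.8 names removed here:
# from diart import OnlineSpeakerDiarization, BasePipeline, BasePipelineConfig

# 0.8 names exported by diart/__init__.py:
from diart import (
    SpeakerDiarization,
    SpeakerDiarizationConfig,
    VoiceActivityDetection,
    VoiceActivityDetectionConfig,
)
```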
1 change: 1 addition & 0 deletions src/diart/argdoc.py
@@ -1,5 +1,6 @@
SEGMENTATION = "Segmentation model name from pyannote"
EMBEDDING = "Embedding model name from pyannote"
DURATION = "Chunk duration (in seconds)"
STEP = "Sliding window step (in seconds)"
LATENCY = "System latency (in seconds). STEP <= LATENCY <= CHUNK_DURATION"
TAU = "Probability threshold to consider a speaker as active. 0 <= TAU <= 1"
5 changes: 3 additions & 2 deletions src/diart/blocks/__init__.py
@@ -13,6 +13,7 @@
    OverlapAwareSpeakerEmbedding,
)
from .segmentation import SpeakerSegmentation
from .diarization import OnlineSpeakerDiarization, BasePipeline
from .config import BasePipelineConfig, PipelineConfig
from .diarization import SpeakerDiarization, SpeakerDiarizationConfig
from .base import PipelineConfig, Pipeline
from .utils import Binarize, Resample, AdjustVolume
from .vad import VoiceActivityDetection, VoiceActivityDetectionConfig