Commit

Update README
juanmc2005 committed Nov 10, 2023
1 parent 864f366 commit 0dd2bba
Showing 1 changed file with 29 additions and 35 deletions.
README.md

## 🤖 Add your model

Third-party models can be integrated by providing a loader function:

```python
from diart import SpeakerDiarization, SpeakerDiarizationConfig
from diart.models import EmbeddingModel, SegmentationModel
from diart.sources import MicrophoneAudioSource
from diart.inference import StreamingInference


def segmentation_loader():
    # It should take a waveform and return a segmentation tensor
    return load_pretrained_model("my_model.ckpt")

def embedding_loader():
    # It should take (waveform, weights) and return per-speaker embeddings
    return load_pretrained_model("my_other_model.ckpt")


segmentation = SegmentationModel(segmentation_loader)
embedding = EmbeddingModel(embedding_loader)
config = SpeakerDiarizationConfig(
    segmentation=segmentation,
    embedding=embedding,
)
pipeline = SpeakerDiarization(config)
mic = MicrophoneAudioSource()
inference = StreamingInference(pipeline, mic)
prediction = inference()
```
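
`load_pretrained_model` above is just a placeholder for whatever loading code your checkpoint needs. As a rough, self-contained illustration (the module below is hypothetical, not part of diart), a segmentation loader only has to return a PyTorch module whose forward pass maps a batch of waveforms to per-frame, per-speaker activity scores:

```python
import torch
import torch.nn as nn


class ToySegmentation(nn.Module):
    """Hypothetical stand-in for a real segmentation network."""

    def __init__(self, num_speakers: int = 3, frame_size: int = 160):
        super().__init__()
        self.frame_size = frame_size
        self.linear = nn.Linear(frame_size, num_speakers)

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        # Assumes pyannote-style input of shape (batch, channels, samples)
        batch, _, samples = waveform.shape
        usable = (samples // self.frame_size) * self.frame_size
        frames = waveform[:, 0, :usable].reshape(batch, -1, self.frame_size)
        # Returns (batch, frames, speakers) activity scores in [0, 1]
        return torch.sigmoid(self.linear(frames))


def segmentation_loader():
    # In practice, load real weights from a checkpoint here
    return ToySegmentation()
```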

If you have an ONNX model, you can use `from_onnx()`:

```python
from diart.models import EmbeddingModel

embedding = EmbeddingModel.from_onnx(
    model_path="my_model.onnx",
    input_names=["x", "w"],  # defaults to ["waveform", "weights"]
    output_name="output",  # defaults to "embedding"
)
```
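
An ONNX-backed model then plugs into the pipeline configuration exactly like the loader-based models; a minimal sketch, reusing the `segmentation` object from the example above:

```python
from diart import SpeakerDiarization, SpeakerDiarizationConfig

config = SpeakerDiarizationConfig(segmentation=segmentation, embedding=embedding)
pipeline = SpeakerDiarization(config)
```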

## 📈 Tune hyper-parameters

Diart implements an optimizer based on [optuna](https://optuna.readthedocs.io/en/stable/index.html) that allows you to tune pipeline hyper-parameters to your needs.
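
A rough sketch of the kind of usage this enables, assuming `diart.optim.Optimizer` is constructed from an audio directory, a reference RTTM directory, and an output directory, and is then called with a number of optimization iterations:

```python
from diart.optim import Optimizer

# Assumed API: tune on a directory of WAV files with matching RTTM references
optimizer = Optimizer("/wav/dir", "/rttm/dir", "/output/dir")
optimizer(num_iter=100)
```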
[…]

or using the inference API:
```python
from diart.inference import Benchmark, Parallelize
from diart import SpeakerDiarization, SpeakerDiarizationConfig
from diart.models import SegmentationModel

benchmark = Benchmark("/wav/dir", "/rttm/dir")

model_name = "pyannote/segmentation@Interspeech2021"
model = SegmentationModel.from_pretrained(model_name)
config = SpeakerDiarizationConfig(
    # Set the segmentation model used in the paper
    segmentation=model,
    step=0.5,
    latency=0.5,
    tau_active=0.555,
    # ...
```
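
`Parallelize` is imported above but its use falls outside this excerpt. Roughly, and assuming the constructor and call signature shown here, it wraps a `Benchmark` so that files are processed by several workers, reusing the `benchmark` and `config` objects from the (truncated) example above:

```python
# Distribute the benchmark files across worker processes (assumed API)
p_benchmark = Parallelize(benchmark, num_workers=4)

if __name__ == "__main__":  # needed for multiprocessing
    p_benchmark(SpeakerDiarization, config)
```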
