Skip to content

Commit

Permalink
Added speaker identification script with cosine and neural classifier… (
Browse files Browse the repository at this point in the history
#3672)

* Added speaker identification script with cosine and neural classifier backends

Signed-off-by: nithinraok <[email protected]>

* updated documentation

Signed-off-by: nithinraok <[email protected]>

* typo fixes

Signed-off-by: nithinraok <[email protected]>
  • Loading branch information
nithinraok authored Feb 16, 2022
1 parent 8ffc92e commit 37fe5b4
Show file tree
Hide file tree
Showing 6 changed files with 195 additions and 152 deletions.
28 changes: 15 additions & 13 deletions docs/source/asr/speaker_recognition/results.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,29 +30,31 @@ Where the model base class is the ASR model class of the original checkpoint, or
Speaker Label Inference
------------------------

Speaker label Inference, is to infer speaker labels from a pretrained speaker model with known speaker labels. We provide `speaker_reco_infer.py` script for this purpose under `<NeMo_root>/examples/speaker_tasks/recognition` folder.
The goal of speaker label inference is to infer speaker labels using a speaker model with known speaker labels from enrollment set. We provide `speaker_identification_infer.py` script for this purpose under `<NeMo_root>/examples/speaker_tasks/recognition` folder.
Currently supported backends are cosine_similarity and neural classifier.

The audio files should be 16KHz mono channel wav files.

Write audio files to a ``manifest.json`` file with lines as in format:
The script takes two manifest files:

.. code-block:: json
{"audio_filepath": "<absolute path to dataset>/audio_file.wav", "duration": "duration of file in sec", "label": "UNK"}
This python call will use the pretrain model and infer labels on provided test set using labels from trained manifest file
* enrollment_manifest : This manifest contains enrollment data with known speaker labels.
* test_manifest: This manifest contains test data for which we map speaker labels captured from enrollment manifest using one of provided backend

sample format for each of these manifests is provided in `<NeMo_root>/examples/speaker_tasks/recognition/conf/speaker_identification_infer.yaml` config file.

To infer speaker labels using cosine_similarity backend

.. code-block:: bash
python speaker_reco_infer.py --spkr_model='/path/to/.nemo/file' --train_manifest='/path/to/train/manifest/file'
--test_manifest=/path/to/train/manifest/file' --batch_size=32
python speaker_identification_infer.py data.enrollment_manifest=<path/to/enrollment_manifest> data.test_manifest=<path/to/test_manifest> backend.backend_model=cosine_similarity
Speaker Embedding Extraction
-----------------------------
Speaker Embedding Extraction, is to extract speaker embeddings for any wav file (from known or unknown speakers). We provide two ways to do this:

* single python liner for extracting embeddings from a single file
* python script for extracting embeddings from a bunch of files provided through manifest file
* single Python liner for extracting embeddings from a single file
* Python script for extracting embeddings from a bunch of files provided through manifest file

For extracting embeddings from a single file:

Expand Down Expand Up @@ -101,14 +103,14 @@ The SpeakerNet-ASR collection has checkpoints of several models trained on vario
The tables below list the speaker embedding extractor models available from NGC, and the models can be accessed via the
:code:`from_pretrained()` method inside the EncDecSpeakerLabelModel Model class.

In general, you can load any of these models with code in the following format.
In general, you can load any of these models with code in the following format:

.. code-block:: python
import nemo.collections.asr as nemo_asr
model = nemo_asr.models.<MODEL_CLASS_NAME>.from_pretrained(model_name="<MODEL_NAME>")
Where the model name is the value under "Model Name" entry in the tables below.
where the model name is the value under "Model Name" entry in the tables below.

If you would like to programatically list the models available for a particular base class, you can use the
:code:`list_available_models()` method.
Expand Down
9 changes: 5 additions & 4 deletions examples/speaker_tasks/recognition/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ We provide generic scripts for manifest file creation, embedding extraction, Vox
We explain here the process for voxceleb EER calculation on voxceleb-O cleaned [trail file](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/veri_test2.txt)

### Manifest Creation
We first generate manifest file to get embeddings, The embeddings are then used by `voxceleb_eval.py` script to get EER
We first generate manifest file to get embeddings. The embeddings are then used by `voxceleb_eval.py` script to get EER

```bash
# create list of files from voxceleb1 test folder (40 speaker test set)
Expand All @@ -70,12 +70,13 @@ python voxceleb_eval.py --trial_file='/path/to/trail/file' --emb='./embeddings/v
The above command gives the performance of models on voxceleb-o cleaned trial file.

### SpeakerID inference
Using data from an enrollment set, one can infer labels on a test set using various backends such as cosine-similarity or a neural classifier.

To infer speaker labels on a model trained with known speaker labels (or fine tuned using pretrained model)
To infer speaker labels using cosine_similarity backend
```bash
python speaker_reco_infer.py --spkr_model='/path/to/.nemo/file' --train_manifest='/path/to/train/manifest/file'
--test_manifest='/path/to/test/manifest/file'
python speaker_identification_infer.py data.enrollment_manifest=<path/to/enrollment_manifest> data.test_manifest=<path/to/test_manifest> backend.backend_model=cosine_similarity
```
refer to conf/speaker_identification_infer.yaml for more options.

## Voxceleb Data Preparation

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
name: &name "SpeakerIdentificationInfer"

data:
enrollment_manifest: ???
test_manifest: ???
out_manifest: './infer_output.json'
sample_rate: 16000

backend:
backend_model: cosine_similarity # supported backends are cosine_similarity and neural_classifier

cosine_similarity:
model_path: titanet_large # or path to .nemo file
batch_size: 32

neural_classifier:
model_path: ??? # path to neural model trained/finetuned with enrollment dataset
batch_size: 32

# json manifest line example
#
# enrollment_manifest:
# {"audio_filepath": "/path/to/audio_file", "offset": 0, "duration": null, "label": "<speaker_label>"}
#
# test_manifest:
# {"audio_filepath": "/path/to/audio_file", "offset": 0, "duration": null, "label": "infer"}
#
114 changes: 114 additions & 0 deletions examples/speaker_tasks/recognition/speaker_identification_infer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json

import numpy as np
import torch
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything

from nemo.collections.asr.data.audio_to_label import AudioToSpeechLabelDataset
from nemo.collections.asr.models import EncDecSpeakerLabelModel
from nemo.collections.asr.parts.features import WaveformFeaturizer
from nemo.core.config import hydra_runner
from nemo.utils import logging

seed_everything(42)


@hydra_runner(config_path="conf", config_name="speaker_identification_infer")
def main(cfg):

logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

device = 'cuda' if torch.cuda.is_available() else 'cpu'

enrollment_manifest = cfg.data.enrollment_manifest
test_manifest = cfg.data.test_manifest
out_manifest = cfg.data.out_manifest
sample_rate = cfg.data.sample_rate

backend = cfg.backend.backend_model.lower()

if backend == 'cosine_similarity':
model_path = cfg.backend.cosine_similarity.model_path
batch_size = cfg.backend.cosine_similarity.batch_size
if model_path.endswith('.nemo'):
speaker_model = EncDecSpeakerLabelModel.restore_from(model_path)
else:
speaker_model = EncDecSpeakerLabelModel.from_pretrained(model_path)

enroll_embs, _, enroll_truelabels, enroll_id2label = EncDecSpeakerLabelModel.get_batch_embeddings(
speaker_model, enrollment_manifest, batch_size, sample_rate, device=device,
)

test_embs, _, _, _ = EncDecSpeakerLabelModel.get_batch_embeddings(
speaker_model, test_manifest, batch_size, sample_rate, device=device,
)

# length normalize
enroll_embs = enroll_embs / (np.linalg.norm(enroll_embs, ord=2, axis=-1, keepdims=True))
test_embs = test_embs / (np.linalg.norm(test_embs, ord=2, axis=-1, keepdims=True))

# reference embedding
reference_embs = []
keyslist = list(enroll_id2label.keys())
for label_id in keyslist:
indices = np.where(enroll_truelabels == label_id)
embedding = (enroll_embs[indices].sum(axis=0).squeeze()) / len(indices)
reference_embs.append(embedding)

reference_embs = np.asarray(reference_embs)

scores = np.matmul(test_embs, reference_embs.T)
matched_labels = scores.argmax(axis=-1)

elif backend == 'neural_classifier':
model_path = cfg.backend.neural_classifier.model_path
batch_size = cfg.backend.neural_classifier.batch_size

if model_path.endswith('.nemo'):
speaker_model = EncDecSpeakerLabelModel.restore_from(model_path)
else:
speaker_model = EncDecSpeakerLabelModel.from_pretrained(model_path)

featurizer = WaveformFeaturizer(sample_rate=sample_rate)
dataset = AudioToSpeechLabelDataset(manifest_filepath=enrollment_manifest, labels=None, featurizer=featurizer)
enroll_id2label = dataset.id2label

if speaker_model.decoder.final.out_features != len(enroll_id2label):
raise ValueError(
"number of labels mis match. Make sure you trained or finetuned neural classifier with labels from enrollement manifest_filepath"
)

_, test_logits, _, _ = EncDecSpeakerLabelModel.get_batch_embeddings(
speaker_model, test_manifest, batch_size, sample_rate, device=device,
)
matched_labels = test_logits.argmax(axis=-1)

with open(test_manifest, 'rb') as f1, open(out_manifest, 'w', encoding='utf-8') as f2:
lines = f1.readlines()
for idx, line in enumerate(lines):
line = line.strip()
item = json.loads(line)
item['infer'] = enroll_id2label[matched_labels[idx]]
json.dump(item, f2)
f2.write('\n')

logging.info("Inference labels have been written to {} manifest file".format(out_manifest))


if __name__ == '__main__':
main()
135 changes: 0 additions & 135 deletions examples/speaker_tasks/recognition/speaker_reco_infer.py

This file was deleted.

Loading

0 comments on commit 37fe5b4

Please sign in to comment.