Skip to content

Commit

Permalink
Fix wrong label mapping in batch_inference for label_model (#5767) (#…
Browse files Browse the repository at this point in the history
…5870)

* fix batch inference



* add test for batch



* fix device

Signed-off-by: fayejf <[email protected]>
Co-authored-by: fayejf <[email protected]>
  • Loading branch information
github-actions[bot] and fayejf authored Jan 27, 2023
1 parent e1b3f5e commit 4ae3ed5
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 1 deletion.
4 changes: 3 additions & 1 deletion nemo/collections/asr/models/label_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,9 @@ def batch_inference(self, manifest_filepath, batch_size=32, sample_rate=16000, d
mapped_labels = list(mapped_labels)

featurizer = WaveformFeaturizer(sample_rate=sample_rate)
dataset = AudioToSpeechLabelDataset(manifest_filepath=manifest_filepath, labels=None, featurizer=featurizer)
dataset = AudioToSpeechLabelDataset(
manifest_filepath=manifest_filepath, labels=mapped_labels, featurizer=featurizer
)

dataloader = torch.utils.data.DataLoader(
dataset=dataset, batch_size=batch_size, collate_fn=dataset.fixed_seq_collate_fn,
Expand Down
22 changes: 22 additions & 0 deletions tests/collections/asr/test_speaker_label_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import tempfile
from unittest import TestCase

import pytest
import torch
from omegaconf import DictConfig

from nemo.collections.asr.models import EncDecSpeakerLabelModel
Expand Down Expand Up @@ -170,3 +173,22 @@ def test_pretrained_ambernet_logits(self, test_data_dir):
label = lang_model.get_label(filename)

assert label == "en"

@pytest.mark.unit
def test_pretrained_ambernet_logits_batched(self, test_data_dir):
model_name = 'langid_ambernet'
lang_model = EncDecSpeakerLabelModel.from_pretrained(model_name)
relative_filepath = "an4_speaker/an4/wav/an4_clstk/fash/an255-fash-b.wav"
filename = os.path.join(test_data_dir, relative_filepath)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

with tempfile.TemporaryDirectory() as tmpdir:
temp_manifest = os.path.join(tmpdir, 'manifest.json')
with open(temp_manifest, 'w', encoding='utf-8') as fp:
entry = {"audio_filepath": filename, "duration": 4.5, "label": 'en'}
fp.write(json.dumps(entry) + '\n')

embs, logits, gt_labels, mapped_labels = lang_model.batch_inference(temp_manifest, device=device)
pred_label = mapped_labels[logits.argmax(axis=-1)[0]]
true_label = mapped_labels[gt_labels[0]]
assert pred_label == true_label

0 comments on commit 4ae3ed5

Please sign in to comment.