diff --git a/nemo/collections/asr/models/label_models.py b/nemo/collections/asr/models/label_models.py index 000c9a976c2d..0191fbd0ba00 100644 --- a/nemo/collections/asr/models/label_models.py +++ b/nemo/collections/asr/models/label_models.py @@ -558,7 +558,9 @@ def batch_inference(self, manifest_filepath, batch_size=32, sample_rate=16000, d mapped_labels = list(mapped_labels) featurizer = WaveformFeaturizer(sample_rate=sample_rate) - dataset = AudioToSpeechLabelDataset(manifest_filepath=manifest_filepath, labels=None, featurizer=featurizer) + dataset = AudioToSpeechLabelDataset( + manifest_filepath=manifest_filepath, labels=mapped_labels, featurizer=featurizer + ) dataloader = torch.utils.data.DataLoader( dataset=dataset, batch_size=batch_size, collate_fn=dataset.fixed_seq_collate_fn, diff --git a/tests/collections/asr/test_speaker_label_models.py b/tests/collections/asr/test_speaker_label_models.py index 92c0d3a53e90..8de3c5bc0813 100644 --- a/tests/collections/asr/test_speaker_label_models.py +++ b/tests/collections/asr/test_speaker_label_models.py @@ -12,10 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import json  # serializes the manifest entry written by the batched test
 import os
+import tempfile  # supplies the temporary directory that holds the test manifest
 from unittest import TestCase

 import pytest
+import torch  # used by the batched test to check CUDA availability
 from omegaconf import DictConfig

 from nemo.collections.asr.models import EncDecSpeakerLabelModel
@@ -170,3 +173,22 @@ def test_pretrained_ambernet_logits(self, test_data_dir):
         label = lang_model.get_label(filename)

         assert label == "en"
+
+    @pytest.mark.unit
+    def test_pretrained_ambernet_logits_batched(self, test_data_dir):  # batch_inference on a one-entry manifest
+        model_name = 'langid_ambernet'
+        lang_model = EncDecSpeakerLabelModel.from_pretrained(model_name)
+        relative_filepath = "an4_speaker/an4/wav/an4_clstk/fash/an255-fash-b.wav"
+        filename = os.path.join(test_data_dir, relative_filepath)
+        device = 'cuda' if torch.cuda.is_available() else 'cpu'  # fall back to CPU when CUDA is unavailable
+
+        with tempfile.TemporaryDirectory() as tmpdir:  # manifest lives only for the duration of this block
+            temp_manifest = os.path.join(tmpdir, 'manifest.json')
+            with open(temp_manifest, 'w', encoding='utf-8') as fp:
+                entry = {"audio_filepath": filename, "duration": 4.5, "label": 'en'}  # NOTE(review): duration is hard-coded; confirm
+                fp.write(json.dumps(entry) + '\n')  # one JSON object followed by a newline
+
+            embs, logits, gt_labels, mapped_labels = lang_model.batch_inference(temp_manifest, device=device)
+            pred_label = mapped_labels[logits.argmax(axis=-1)[0]]  # argmax over the last axis, first entry, mapped to a label
+            true_label = mapped_labels[gt_labels[0]]  # first ground-truth index mapped to a label
+            assert pred_label == true_label