diff --git a/examples/speaker_recognition/speaker_reco.py b/examples/speaker_recognition/speaker_reco.py index 3259a6512b56..0a30a93a2160 100644 --- a/examples/speaker_recognition/speaker_reco.py +++ b/examples/speaker_recognition/speaker_reco.py @@ -24,22 +24,20 @@ """ Basic run (on CPU for 50 epochs): - python examples/speaker_recognition/speaker_reco.py \ - model.train_ds.manifest_filepath="" \ - model.validation_ds.manifest_filepath="" \ - hydra.run.dir="." \ - trainer.gpus=0 \ - trainer.max_epochs=50 - +EXP_NAME=sample_run +python ./speaker_reco.py --config-path='conf' --config-name='config.yaml' \ + trainer.max_epochs=10 \ + model.train_ds.batch_size=64 model.validation_ds.batch_size=64 \ + trainer.gpus=0 \ + model.decoder.params.num_classes=2 \ + exp_manager.name=$EXP_NAME +exp_manager.use_datetime_version=False \ + exp_manager.exp_dir='./speaker_exps' Add PyTorch Lightning Trainer arguments from CLI: python speaker_reco.py \ ... \ +trainer.fast_dev_run=true -Hydra logs will be found in "$(./outputs/$(date +"%y-%m-%d")/$(date +"%H-%M-%S")/.hydra)" -PTL logs will be found in "$(./outputs/$(date +"%y-%m-%d")/$(date +"%H-%M-%S")/lightning_logs)" - """ seed_everything(42) diff --git a/examples/speaker_recognition/spkr_get_emb.py b/examples/speaker_recognition/spkr_get_emb.py index de4887c0a262..218c23817ca8 100644 --- a/examples/speaker_recognition/spkr_get_emb.py +++ b/examples/speaker_recognition/spkr_get_emb.py @@ -23,23 +23,13 @@ from nemo.utils.exp_manager import exp_manager """ -Basic run (on CPU for 50 epochs): - python examples/speaker_recognition/speaker_reco.py \ - model.train_ds.manifest_filepath="" \ - model.validation_ds.manifest_filepath="" \ +To extract embeddings + python examples/speaker_recognition/spkr_get_emb.py \ + model.test_ds.manifest_filepath="" \ + exp_manager.exp_name="" + exp_manager.exp_dir="" hydra.run.dir="." \ - trainer.gpus=0 \ - trainer.max_epochs=50 - - -Add PyTorch Lightning Trainer arguments from CLI: - python speaker_reco.py \ - ... 
\ - +trainer.fast_dev_run=true - -Hydra logs will be found in "$(./outputs/$(date +"%y-%m-%d")/$(date +"%H-%M-%S")/.hydra)" -PTL logs will be found in "$(./outputs/$(date +"%y-%m-%d")/$(date +"%H-%M-%S")/lightning_logs)" - + trainer.gpus=1 """ seed_everything(42) @@ -49,12 +39,15 @@ def main(cfg): logging.info(f'Hydra config: {cfg.pretty()}') - trainer = pl.Trainer(logger=False, checkpoint_callback=False) + if cfg.trainer.gpus > 1: + logging.info("changing gpus to 1 to minimize DDP issues while extracting embeddings") + cfg.trainer.gpus = 1 + cfg.trainer.distributed_backend = None + trainer = pl.Trainer(**cfg.trainer) log_dir = exp_manager(trainer, cfg.get("exp_manager", None)) model_path = os.path.join(log_dir, '..', 'spkr.nemo') speaker_model = ExtractSpeakerEmbeddingsModel.restore_from(model_path) speaker_model.setup_test_data(cfg.model.test_ds) - trainer.test(speaker_model) diff --git a/nemo/collections/asr/data/audio_to_label.py b/nemo/collections/asr/data/audio_to_label.py index 3d35b45611db..ac5aed51f819 100644 --- a/nemo/collections/asr/data/audio_to_label.py +++ b/nemo/collections/asr/data/audio_to_label.py @@ -96,6 +96,7 @@ def __init__( self.trim = trim self.load_audio = load_audio self.time_length = time_length + logging.info("Timelength considered for collate func is {}".format(time_length)) self.labels = labels if labels else self.collection.uniq_labels self.num_classes = len(self.labels) @@ -120,9 +121,9 @@ def fixed_seq_collate_fn(self, batch): _, audio_lengths, _, tokens_lengths = zip(*batch) has_audio = audio_lengths[0] is not None - fixed_length = min(fixed_length, max(audio_lengths)) + fixed_length = int(min(fixed_length, max(audio_lengths))) - audio_signal, tokens = [], [] + audio_signal, tokens, new_audio_lengths = [], [], [] for sig, sig_len, tokens_i, _ in batch: if has_audio: sig_len = sig_len.item() @@ -134,17 +135,19 @@ def fixed_seq_collate_fn(self, batch): sub = sig[-rem:] if rem > 0 else torch.tensor([]) rep_sig = torch.cat(repeat * [sig]) signal = torch.cat((rep_sig, sub)) + new_audio_lengths.append(torch.tensor(fixed_length)) else: start_idx = torch.randint(0, chunck_len, (1,)) if chunck_len else torch.tensor(0) end_idx = start_idx + fixed_length signal = sig[start_idx:end_idx] + new_audio_lengths.append(torch.tensor(fixed_length)) audio_signal.append(signal) tokens.append(tokens_i) if has_audio: audio_signal = torch.stack(audio_signal) - audio_lengths = torch.stack(audio_lengths) + audio_lengths = torch.stack(new_audio_lengths) else: audio_signal, audio_lengths = None, None tokens = torch.stack(tokens) @@ -152,6 +155,59 @@ def fixed_seq_collate_fn(self, batch): return audio_signal, audio_lengths, tokens, tokens_lengths + def sliced_seq_collate_fn(self, batch): + """collate batch of audio sig, audio len, tokens, tokens len + Args: + batch (Optional[FloatTensor], Optional[LongTensor], LongTensor, + LongTensor): A tuple of tuples of signal, signal lengths, + encoded tokens, and encoded tokens length. This collate func + assumes the signals are 1d torch tensors (i.e. mono audio). 
+ fixed_length (Optional[int]): length of input signal to be considered + """ + slice_length = self.featurizer.sample_rate * self.time_length + _, audio_lengths, _, tokens_lengths = zip(*batch) + slice_length = min(slice_length, max(audio_lengths)) + shift = 1 * self.featurizer.sample_rate + has_audio = audio_lengths[0] is not None + + audio_signal, num_slices, tokens, audio_lengths = [], [], [], [] + for sig, sig_len, tokens_i, _ in batch: + if has_audio: + sig_len = sig_len.item() + slices = sig_len // slice_length + if slices <= 0: + + repeat = slice_length // sig_len + rem = slice_length % sig_len + sub = sig[-rem:] if rem > 0 else torch.tensor([]) + rep_sig = torch.cat(repeat * [sig]) + signal = torch.cat((rep_sig, sub)) + audio_signal.append(signal) + num_slices.append(1) # single embedding + tokens.extend([tokens_i] * 1) + audio_lengths.extend([slice_length] * 1) + else: + slices = (sig_len - slice_length) // shift + 1 + for slice_id in range(slices): + start_idx = slice_id * shift + end_idx = start_idx + slice_length + signal = sig[start_idx:end_idx] + audio_signal.append(signal) + + num_slices.append(slices) + tokens.extend([tokens_i] * slices) + audio_lengths.extend([slice_length] * slices) + + if has_audio: + audio_signal = torch.stack(audio_signal) + audio_lengths = torch.tensor(audio_lengths) + else: + audio_signal, audio_lengths = None, None + tokens = torch.stack(tokens) + tokens_lengths = torch.tensor(num_slices) # each embedding length + + return audio_signal, audio_lengths, tokens, tokens_lengths + def __len__(self): return len(self.collection) diff --git a/nemo/collections/asr/losses/angularloss.py b/nemo/collections/asr/losses/angularloss.py new file mode 100644 index 000000000000..e2aee9bba6ea --- /dev/null +++ b/nemo/collections/asr/losses/angularloss.py @@ -0,0 +1,68 @@ +# ! /usr/bin/python +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from nemo.core.classes import Loss, Typing, typecheck +from nemo.core.neural_types import LabelsType, LogitsType, LossType, NeuralType + +__all__ = ['AngularSoftmaxLoss'] + + +class AngularSoftmaxLoss(Loss, Typing): + """ + Computes ArcFace Angular softmax angle loss + reference: https://openaccess.thecvf.com/content_CVPR_2019/papers/Deng_ArcFace_Additive_Angular_Margin_Loss_for_Deep_Face_Recognition_CVPR_2019_paper.pdf + args: + scale: scale value for cosine angle + margin: margin value added to cosine angle + """ + + @property + def input_types(self): + """Input types definitions for AnguarLoss. + """ + return { + "logits": NeuralType(('B', 'D'), LogitsType()), + "labels": NeuralType(('B',), LabelsType()), + } + + @property + def output_types(self): + """Output types definitions for AngularLoss. 
+ loss: + NeuralType(None) + """ + return {"loss": NeuralType(elements_type=LossType())} + + def __init__(self, scale=20.0, margin=1.35): + super().__init__() + + self.eps = 1e-7 + self.scale = scale + self.margin = margin + + @typecheck() + def forward(self, logits, labels): + numerator = self.scale * torch.cos( + torch.acos(torch.clamp(torch.diagonal(logits.transpose(0, 1)[labels]), -1.0 + self.eps, 1 - self.eps)) + + self.margin + ) + excl = torch.cat( + [torch.cat((logits[i, :y], logits[i, y + 1 :])).unsqueeze(0) for i, y in enumerate(labels)], dim=0 + ) + denominator = torch.exp(numerator) + torch.sum(torch.exp(self.scale * excl), dim=1) + L = numerator - torch.log(denominator) + return -torch.mean(L) diff --git a/nemo/collections/asr/models/label_models.py b/nemo/collections/asr/models/label_models.py index e6eb1961a0c5..9837d1066e15 100644 --- a/nemo/collections/asr/models/label_models.py +++ b/nemo/collections/asr/models/label_models.py @@ -17,20 +17,20 @@ import pickle as pkl from typing import Dict, Optional, Union -import numpy as np import torch from omegaconf import DictConfig from pytorch_lightning import Trainer from nemo.collections.asr.data.audio_to_label import AudioToSpeechLabelDataSet +from nemo.collections.asr.losses.angularloss import AngularSoftmaxLoss from nemo.collections.asr.parts.features import WaveformFeaturizer from nemo.collections.asr.parts.perturb import process_augmentations from nemo.collections.common.losses import CrossEntropyLoss as CELoss +from nemo.collections.common.metrics import TopKClassificationAccuracy, compute_topk_accuracy from nemo.core.classes import ModelPT from nemo.core.classes.common import typecheck from nemo.core.neural_types import * from nemo.utils import logging -from nemo.utils.decorators import experimental __all__ = ['EncDecSpeakerLabelModel', 'ExtractSpeakerEmbeddingsModel'] @@ -50,7 +50,16 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.preprocessor = EncDecSpeakerLabelModel.from_config_dict(cfg.preprocessor) self.encoder = EncDecSpeakerLabelModel.from_config_dict(cfg.encoder) self.decoder = EncDecSpeakerLabelModel.from_config_dict(cfg.decoder) - self.loss = CELoss() + if 'angular' in cfg.decoder.params and cfg.decoder.params['angular']: + logging.info("Training with Angular Softmax Loss") + scale = cfg.loss.scale + margin = cfg.loss.margin + self.loss = AngularSoftmaxLoss(scale=scale, margin=margin) + else: + logging.info("Training with Softmax-CrossEntropy loss") + self.loss = CELoss() + + self._accuracy = TopKClassificationAccuracy(top_k=[1]) def __setup_dataloader_from_config(self, config: Optional[Dict]): if 'augmentor' in config: @@ -139,38 +148,48 @@ def forward(self, input_signal, input_signal_length): return logits, embs # PTL-specific methods - def training_step(self, batch, batch_nb): + def training_step(self, batch, batch_idx): audio_signal, audio_signal_len, labels, _ = batch logits, _ = self.forward(input_signal=audio_signal, input_signal_length=audio_signal_len) - loss_value = self.loss(logits=logits, labels=labels) - labels_hat = torch.argmax(logits, dim=1) - n_correct_pred = torch.sum(labels == labels_hat, dim=0).item() - tensorboard_logs = {'train_loss': loss_value, 'training_batch_acc': (n_correct_pred / len(labels)) * 100} + self.loss_value = self.loss(logits=logits, labels=labels) + + tensorboard_logs = { + 'train_loss': self.loss_value, + 'learning_rate': self._optimizer.param_groups[0]['lr'], + } - return {'loss': loss_value, 'log': tensorboard_logs, "n_correct_pred": 
n_correct_pred, "n_pred": len(labels)} + correct_counts, total_counts = self._accuracy(logits=logits, labels=labels) - def training_epoch_end(self, outputs): - train_acc = (sum([x['n_correct_pred'] for x in outputs]) / sum(x['n_pred'] for x in outputs)) * 100 - tensorboard_logs = {'train_acc': train_acc} + for ki in range(correct_counts.shape[-1]): + correct_count = correct_counts[ki] + total_count = total_counts[ki] + top_k = self._accuracy.top_k[ki] + self.accuracy = (correct_count / float(total_count)) * 100 - return {'train_acc': train_acc, 'log': tensorboard_logs} + tensorboard_logs['training_batch_accuracy_top@{}'.format(top_k)] = self.accuracy - def validation_step(self, batch, batch_idx): + return {'loss': self.loss_value, 'log': tensorboard_logs} + + def validation_step(self, batch, batch_idx, dataloader_idx: int = 0): audio_signal, audio_signal_len, labels, _ = batch logits, _ = self.forward(input_signal=audio_signal, input_signal_length=audio_signal_len) - loss_value = self.loss(logits=logits, labels=labels) - labels_hat = torch.argmax(logits, dim=1) - n_correct_pred = torch.sum(labels == labels_hat, dim=0).item() + self.loss_value = self.loss(logits=logits, labels=labels) + correct_counts, total_counts = self._accuracy(logits=logits, labels=labels) + return {'val_loss': self.loss_value, 'val_correct_counts': correct_counts, 'val_total_counts': total_counts} - return {'val_loss': loss_value, "n_correct_pred": n_correct_pred, "n_pred": len(labels)} + def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0): + self.val_loss_mean = torch.stack([x['val_loss'] for x in outputs]).mean() + correct_counts = torch.stack([x['val_correct_counts'] for x in outputs]) + total_counts = torch.stack([x['val_total_counts'] for x in outputs]) - def validation_epoch_end(self, outputs): - val_loss_mean = torch.stack([x['val_loss'] for x in outputs]).mean() - val_acc = (sum([x['n_correct_pred'] for x in outputs]) / sum(x['n_pred'] for x in outputs)) * 100 - logging.info("validation accuracy {:.3f}".format(val_acc)) - tensorboard_logs = {'validation_loss': val_loss_mean, 'validation_acc': val_acc} + topk_scores = compute_topk_accuracy(correct_counts, total_counts) + logging.info("val_loss: {:.3f}".format(self.val_loss_mean)) + tensorboard_log = {'val_loss': self.val_loss_mean} + for top_k, score in zip(self._accuracy.top_k, topk_scores): + tensorboard_log['val_epoch_top@{}'.format(top_k)] = score + self.accuracy = score * 100 - return {'val_loss': val_loss_mean, 'log': tensorboard_logs} + return {'log': tensorboard_log} def test_step(self, batch, batch_ix): audio_signal, audio_signal_len, labels, _ = batch @@ -188,16 +207,17 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): super().__init__(cfg=cfg, trainer=trainer) def test_step(self, batch, batch_ix): - audio_signal, audio_signal_len, labels, _ = batch + audio_signal, audio_signal_len, labels, slices = batch _, embs = self.forward(input_signal=audio_signal, input_signal_length=audio_signal_len) - return {'embs': embs, 'labels': labels} + return {'embs': embs, 'labels': labels, 'slices': slices} def test_epoch_end(self, outputs): embs = torch.cat([x['embs'] for x in outputs]) + slices = torch.cat([x['slices'] for x in outputs]) emb_shape = embs.shape[-1] embs = embs.view(-1, emb_shape).cpu().numpy() out_embeddings = {} - + start_idx = 0 with open(self.test_manifest, 'r') as manifest: for idx, line in enumerate(manifest.readlines()): line = line.strip() @@ -206,7 +226,10 @@ def test_epoch_end(self, outputs): uniq_name = 
'@'.join(structure) if uniq_name in out_embeddings: raise KeyError("Embeddings for label {} already present in emb dictionary".format(uniq_name)) - out_embeddings[uniq_name] = embs[idx] + num_slices = slices[idx] + end_idx = start_idx + num_slices + out_embeddings[uniq_name] = embs[start_idx:end_idx].mean(axis=0) + start_idx = end_idx embedding_dir = os.path.join(self.embedding_dir, 'embeddings') if not os.path.exists(embedding_dir): diff --git a/nemo/collections/asr/modules/conv_asr.py b/nemo/collections/asr/modules/conv_asr.py index 28b3431d5ce6..8ff66e5a347a 100644 --- a/nemo/collections/asr/modules/conv_asr.py +++ b/nemo/collections/asr/modules/conv_asr.py @@ -15,6 +15,7 @@ import torch import torch.nn as nn +import torch.nn.functional as F from omegaconf import ListConfig, OmegaConf from nemo.collections.asr.parts.jasper import ( @@ -36,7 +37,6 @@ SpectrogramType, ) from nemo.utils import logging -from nemo.utils.decorators import experimental __all__ = ['ConvASRDecoder', 'ConvASREncoder', 'ConvASRDecoderClassification'] @@ -356,12 +356,20 @@ def output_types(self): ) def __init__( - self, feat_in, num_classes, emb_sizes=[1024, 1024], pool_mode='xvector', init_mode="xavier_uniform", + self, feat_in, num_classes, emb_sizes=None, pool_mode='xvector', angular=False, init_mode="xavier_uniform", ): super().__init__() + self.angular = angular + self.emb_id = 2 + if self.angular: + bias = False + else: + bias = True if type(emb_sizes) is str: emb_sizes = emb_sizes.split(',') + elif emb_sizes == None: + emb_sizes = [512, 512] else: emb_sizes = list(emb_sizes) @@ -380,7 +388,7 @@ def __init__( self.emb_layers = nn.ModuleList(emb_layers) - self.final = nn.Linear(shapes[-1], self._num_classes) + self.final = nn.Linear(shapes[-1], self._num_classes, bias=bias) self.apply(lambda x: init_weights(x, mode=init_mode)) @@ -399,9 +407,14 @@ def forward(self, encoder_output): embs = [] for layer in self.emb_layers: - pool, emb = layer(pool), layer[:2](pool) + pool, emb = layer(pool), layer[: self.emb_id](pool) embs.append(emb) + if self.angular: + for W in self.final.parameters(): + W = F.normalize(W, p=2, dim=1) + pool = F.normalize(pool, p=2, dim=1) + out = self.final(pool) return out, embs[-1]
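When `angular=True`, `ConvASRDecoderClassification` drops the bias on its final linear layer and L2-normalizes the pooled statistics so that, with unit-norm class weights, each logit becomes a cosine similarity that the angular loss can consume. (Note that `F.normalize` is not in-place, so the `for W in self.final.parameters()` loop above rebinds a local name rather than renormalizing the stored weight tensor.) The sketch below is a hypothetical standalone head, not the NeMo module itself, showing the functional form of such a cosine-similarity classifier; `CosineClassifierHead` and its dimensions are illustrative.

import torch
import torch.nn as nn
import torch.nn.functional as F


class CosineClassifierHead(nn.Module):
    """Minimal sketch (hypothetical, not the NeMo decoder) of a bias-free
    classification head whose logits are cos(theta) between the pooled
    embedding and each class weight vector."""

    def __init__(self, emb_dim: int, num_classes: int):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(num_classes, emb_dim))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, pool: torch.Tensor) -> torch.Tensor:
        # normalizing both operands makes every logit a cosine in [-1, 1]
        return F.linear(F.normalize(pool, p=2, dim=1),
                        F.normalize(self.weight, p=2, dim=1))


head = CosineClassifierHead(emb_dim=512, num_classes=10)
cos_logits = head(torch.randn(8, 512))
print(cos_logits.shape)  # torch.Size([8, 10]), values bounded by [-1, 1]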
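The new `AngularSoftmaxLoss` adds the ArcFace-style margin to the target-class angle and rescales before the softmax, so it expects logits that are already cosines in [-1, 1] rather than unconstrained scores. A minimal usage sketch follows, assuming the import path added by this diff; the batch size, embedding width, and class count are illustrative, and the scale/margin values are the class defaults.

import torch
import torch.nn.functional as F

from nemo.collections.asr.losses.angularloss import AngularSoftmaxLoss

loss_fn = AngularSoftmaxLoss(scale=20.0, margin=1.35)   # defaults from the class

batch, num_classes, emb_dim = 4, 10, 256                # illustrative sizes
embs = F.normalize(torch.randn(batch, emb_dim, requires_grad=True), dim=1)
class_weights = F.normalize(torch.randn(num_classes, emb_dim), dim=1)
cos_logits = embs @ class_weights.t()                   # cosine logits in [-1, 1]
labels = torch.randint(0, num_classes, (batch,))

loss = loss_fn(logits=cos_logits, labels=labels)        # scalar loss
loss.backward()                                         # gradients flow back to embs

In `EncDecSpeakerLabelModel.__init__` this path is selected when `cfg.decoder.params.angular` is true, with `cfg.loss.scale` and `cfg.loss.margin` supplying the two hyperparameters; otherwise the model falls back to the plain softmax cross-entropy loss.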
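`sliced_seq_collate_fn` turns each utterance into one or more fixed-length slices: clips shorter than the window are repeat-padded into a single slice, longer clips are cut with a window of `sample_rate * time_length` samples shifted by one second, and the per-utterance slice counts are returned in place of token lengths so the embeddings can be regrouped later. A rough standalone sketch of the per-utterance logic follows; the helper name `slice_signal`, the 16 kHz sample rate, and the 3 s window are assumptions for illustration.

import torch


def slice_signal(sig, slice_length, shift):
    """Hypothetical helper mirroring the per-utterance logic of
    sliced_seq_collate_fn: repeat-pad clips shorter than the window,
    otherwise slide a slice_length window forward by `shift` samples."""
    sig_len = sig.shape[0]
    if sig_len // slice_length <= 0:
        # repeat the whole clip, then append a tail remainder to reach slice_length
        repeat = slice_length // sig_len
        rem = slice_length % sig_len
        sub = sig[-rem:] if rem > 0 else torch.tensor([])
        return [torch.cat((torch.cat(repeat * [sig]), sub))]
    num_slices = (sig_len - slice_length) // shift + 1
    return [sig[i * shift:i * shift + slice_length] for i in range(num_slices)]


sr = 16000                                    # assumed sample rate
window, shift = 3 * sr, 1 * sr                # 3 s window, 1 s shift (as in the collate fn)
short_clip = torch.randn(int(2.5 * sr))       # shorter than the window -> 1 repeat-padded slice
long_clip = torch.randn(5 * sr)               # longer -> 3 overlapping slices
print(len(slice_signal(short_clip, window, shift)),
      len(slice_signal(long_clip, window, shift)))   # 1 3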
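On the extraction side, `ExtractSpeakerEmbeddingsModel.test_epoch_end` walks the stacked per-slice embeddings using the slice counts returned by the collate function and stores the mean embedding per manifest entry. A simplified sketch of that regrouping, with a hypothetical helper name and illustrative sizes:

import torch


def average_slice_embeddings(embs, slice_counts):
    """Collapse per-slice embeddings back to one embedding per utterance,
    mirroring the start_idx/end_idx bookkeeping in test_epoch_end."""
    out, start = [], 0
    for count in slice_counts.tolist():
        out.append(embs[start:start + count].mean(dim=0))
        start += count
    return torch.stack(out)


embs = torch.randn(6, 256)                        # 6 slices in total
utt_embs = average_slice_embeddings(embs, torch.tensor([1, 3, 2]))
print(utt_embs.shape)                             # torch.Size([3, 256]) -> 3 utterances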