Angular loss1.0 #1101

Merged
14 commits merged on Sep 4, 2020
18 changes: 8 additions & 10 deletions examples/speaker_recognition/speaker_reco.py
@@ -24,22 +24,20 @@

"""
Basic run (on CPU for 50 epochs):
python examples/speaker_recognition/speaker_reco.py \
model.train_ds.manifest_filepath="<train_manifest_file>" \
model.validation_ds.manifest_filepath="<validation_manifest_file>" \
hydra.run.dir="." \
trainer.gpus=0 \
trainer.max_epochs=50

EXP_NAME=sample_run
python ./speaker_reco.py --config-path='conf' --config-name='config.yaml' \
trainer.max_epochs=10 \
model.train_ds.batch_size=64 model.validation_ds.batch_size=64 \
trainer.gpus=0 \
model.decoder.params.num_classes=2 \
exp_manager.name=$EXP_NAME +exp_manager.use_datetime_version=False \
exp_manager.exp_dir='./speaker_exps'

Add PyTorch Lightning Trainer arguments from CLI:
python speaker_reco.py \
... \
+trainer.fast_dev_run=true

Hydra logs will be found in "$(./outputs/$(date +"%y-%m-%d")/$(date +"%H-%M-%S")/.hydra)"
PTL logs will be found in "$(./outputs/$(date +"%y-%m-%d")/$(date +"%H-%M-%S")/lightning_logs)"

"""

seed_everything(42)
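Side note: the manifest files referenced in the run commands above are JSON-lines files. A minimal sketch for generating one is below; the exact keys (`audio_filepath`, `duration`, `label`) are an assumption based on other NeMo ASR examples and should be checked against the repository's documentation.

```python
# Minimal sketch: build a JSON-lines manifest for speaker_reco.py.
# Assumption: each line needs "audio_filepath", "duration", and "label";
# verify the exact schema against the NeMo speaker-recognition docs.
import json
import wave


def write_manifest(entries, manifest_path):
    """entries: iterable of (wav_path, speaker_label) pairs."""
    with open(manifest_path, "w") as fout:
        for wav_path, label in entries:
            with wave.open(wav_path) as wav:
                duration = wav.getnframes() / wav.getframerate()
            fout.write(json.dumps({
                "audio_filepath": wav_path,
                "duration": duration,
                "label": label,
            }) + "\n")


# Hypothetical usage with a local wav file:
# write_manifest([("speaker0_utt0.wav", "speaker0")], "train_manifest.json")
```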
29 changes: 11 additions & 18 deletions examples/speaker_recognition/spkr_get_emb.py
@@ -23,23 +23,13 @@
from nemo.utils.exp_manager import exp_manager

"""
Basic run (on CPU for 50 epochs):
python examples/speaker_recognition/speaker_reco.py \
model.train_ds.manifest_filepath="<train_manifest_file>" \
model.validation_ds.manifest_filepath="<validation_manifest_file>" \
To extract embeddings:
python examples/speaker_recognition/spkr_get_emb.py \
model.test_ds.manifest_filepath="<validation_manifest_file>" \
exp_manager.exp_name="<trained_model_name>" \
exp_manager.exp_dir="<path to model checkpoint directories>"
hydra.run.dir="." \
trainer.gpus=0 \
trainer.max_epochs=50


Add PyTorch Lightning Trainer arguments from CLI:
python speaker_reco.py \
... \
+trainer.fast_dev_run=true

Hydra logs will be found in "$(./outputs/$(date +"%y-%m-%d")/$(date +"%H-%M-%S")/.hydra)"
PTL logs will be found in "$(./outputs/$(date +"%y-%m-%d")/$(date +"%H-%M-%S")/lightning_logs)"

trainer.gpus=1
"""

seed_everything(42)
@@ -49,12 +39,15 @@
def main(cfg):

logging.info(f'Hydra config: {cfg.pretty()}')
trainer = pl.Trainer(logger=False, checkpoint_callback=False)
if cfg.trainer.gpus > 1:
Review thread:

Collaborator: Wait, do this only during inference (trainer.test()), otherwise you can't use multi-GPU training.

Collaborator (author): spkr_get_emb.py is only run for inference purposes.

Collaborator: Ah, ok.

logging.info("changing gpus to 1 to minimize DDP issues while extracting embeddings")
cfg.trainer.gpus = 1
cfg.trainer.distributed_backend = None
trainer = pl.Trainer(**cfg.trainer)
log_dir = exp_manager(trainer, cfg.get("exp_manager", None))
model_path = os.path.join(log_dir, '..', 'spkr.nemo')
speaker_model = ExtractSpeakerEmbeddingsModel.restore_from(model_path)
speaker_model.setup_test_data(cfg.model.test_ds)

trainer.test(speaker_model)


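Note: the guard above only makes sense because this script is inference-only, as confirmed in the review thread. A standalone sketch of the same pattern, using a plain dict in place of the Hydra config (key names assumed to match the script), would look like:

```python
# Sketch of the single-device guard used above, with a plain dict standing in
# for the Hydra/OmegaConf trainer config; key names mirror the script but are
# assumptions here.
import pytorch_lightning as pl


def build_inference_trainer(trainer_cfg: dict) -> pl.Trainer:
    cfg = dict(trainer_cfg)
    if cfg.get("gpus", 0) > 1:
        # Multi-GPU DDP would shard the test set across ranks, so each process
        # would only see part of the embeddings; force a single device instead.
        cfg["gpus"] = 1
        cfg["distributed_backend"] = None
    return pl.Trainer(**cfg)
```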
62 changes: 59 additions & 3 deletions nemo/collections/asr/data/audio_to_label.py
@@ -96,6 +96,7 @@ def __init__(
self.trim = trim
self.load_audio = load_audio
self.time_length = time_length
logging.info("Timelength considered for collate func is {}".format(time_length))

self.labels = labels if labels else self.collection.uniq_labels
self.num_classes = len(self.labels)
@@ -120,9 +121,9 @@ def fixed_seq_collate_fn(self, batch):
_, audio_lengths, _, tokens_lengths = zip(*batch)

has_audio = audio_lengths[0] is not None
fixed_length = min(fixed_length, max(audio_lengths))
fixed_length = int(min(fixed_length, max(audio_lengths)))

audio_signal, tokens = [], []
audio_signal, tokens, new_audio_lengths = [], [], []
for sig, sig_len, tokens_i, _ in batch:
if has_audio:
sig_len = sig_len.item()
@@ -134,24 +135,79 @@
sub = sig[-rem:] if rem > 0 else torch.tensor([])
rep_sig = torch.cat(repeat * [sig])
signal = torch.cat((rep_sig, sub))
new_audio_lengths.append(torch.tensor(fixed_length))
else:
start_idx = torch.randint(0, chunck_len, (1,)) if chunck_len else torch.tensor(0)
end_idx = start_idx + fixed_length
signal = sig[start_idx:end_idx]
new_audio_lengths.append(torch.tensor(fixed_length))

audio_signal.append(signal)
tokens.append(tokens_i)

if has_audio:
audio_signal = torch.stack(audio_signal)
audio_lengths = torch.stack(audio_lengths)
audio_lengths = torch.stack(new_audio_lengths)
else:
audio_signal, audio_lengths = None, None
tokens = torch.stack(tokens)
tokens_lengths = torch.stack(tokens_lengths)

return audio_signal, audio_lengths, tokens, tokens_lengths
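Side note: the repeat-or-random-crop behaviour in fixed_seq_collate_fn can be exercised on its own. A minimal standalone sketch (not the dataset class itself) under the same assumptions:

```python
# Standalone sketch of the fixed-length logic above: signals shorter than
# fixed_length are tiled (plus a tail remainder), longer ones are randomly cropped.
import torch


def to_fixed_length(sig: torch.Tensor, fixed_length: int) -> torch.Tensor:
    sig_len = sig.shape[0]
    if sig_len < fixed_length:
        repeat = fixed_length // sig_len
        rem = fixed_length % sig_len
        sub = sig[-rem:] if rem > 0 else torch.tensor([])
        return torch.cat(repeat * [sig] + [sub])
    chunk_len = sig_len - fixed_length
    start_idx = torch.randint(0, chunk_len, (1,)) if chunk_len else torch.tensor(0)
    return sig[start_idx : start_idx + fixed_length]


print(to_fixed_length(torch.randn(16000), 48000).shape)  # torch.Size([48000])
```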

def sliced_seq_collate_fn(self, batch):
"""collate batch of audio sig, audio len, tokens, tokens len
Args:
batch (Optional[FloatTensor], Optional[LongTensor], LongTensor,
LongTensor): A tuple of tuples of signal, signal lengths,
encoded tokens, and encoded tokens length. This collate func
assumes the signals are 1d torch tensors (i.e. mono audio).
The slice length is derived from self.time_length and the featurizer sample rate.
"""
slice_length = self.featurizer.sample_rate * self.time_length
_, audio_lengths, _, tokens_lengths = zip(*batch)
slice_length = min(slice_length, max(audio_lengths))
shift = 1 * self.featurizer.sample_rate
has_audio = audio_lengths[0] is not None

audio_signal, num_slices, tokens, audio_lengths = [], [], [], []
for sig, sig_len, tokens_i, _ in batch:
if has_audio:
sig_len = sig_len.item()
slices = sig_len // slice_length
if slices <= 0:

repeat = slice_length // sig_len
rem = slice_length % sig_len
sub = sig[-rem:] if rem > 0 else torch.tensor([])
rep_sig = torch.cat(repeat * [sig])
signal = torch.cat((rep_sig, sub))
audio_signal.append(signal)
num_slices.append(1) # single embedding
tokens.extend([tokens_i] * 1)
audio_lengths.extend([slice_length] * 1)
else:
slices = (sig_len - slice_length) // shift + 1
for slice_id in range(slices):
start_idx = slice_id * shift
end_idx = start_idx + slice_length
signal = sig[start_idx:end_idx]
audio_signal.append(signal)

num_slices.append(slices)
tokens.extend([tokens_i] * slices)
audio_lengths.extend([slice_length] * slices)

if has_audio:
audio_signal = torch.stack(audio_signal)
audio_lengths = torch.tensor(audio_lengths)
else:
audio_signal, audio_lengths = None, None
tokens = torch.stack(tokens)
tokens_lengths = torch.tensor(num_slices) # each embedding length

return audio_signal, audio_lengths, tokens, tokens_lengths
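For intuition, with the 1 second shift above an utterance of sig_len samples yields (sig_len - slice_length) // shift + 1 overlapping slices, or a single repeated-and-padded slice when it is shorter than the window. A quick check of that arithmetic, assuming 16 kHz audio and an 8 s window (the concrete numbers are illustrative only):

```python
# Quick check of the slice-count arithmetic used in sliced_seq_collate_fn.
sample_rate = 16000
slice_length = 8 * sample_rate   # 8 s window, an assumption for illustration
shift = 1 * sample_rate          # 1 s hop, as in the collate function

for seconds in (8, 10, 15):
    sig_len = seconds * sample_rate
    slices = (sig_len - slice_length) // shift + 1
    print(seconds, "s ->", slices, "slices")  # 8 s -> 1, 10 s -> 3, 15 s -> 8
```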

def __len__(self):
return len(self.collection)

68 changes: 68 additions & 0 deletions nemo/collections/asr/losses/angularloss.py
@@ -0,0 +1,68 @@
# ! /usr/bin/python
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from nemo.core.classes import Loss, Typing, typecheck
from nemo.core.neural_types import LabelsType, LogitsType, LossType, NeuralType

__all__ = ['AngularSoftmaxLoss']


class AngularSoftmaxLoss(Loss, Typing):
"""
Computes the ArcFace additive angular margin softmax loss.
Reference: https://openaccess.thecvf.com/content_CVPR_2019/papers/Deng_ArcFace_Additive_Angular_Margin_Loss_for_Deep_Face_Recognition_CVPR_2019_paper.pdf
Args:
scale: scale value for the cosine angle
margin: margin value added to the cosine angle
"""

@property
def input_types(self):
"""Input types definitions for AnguarLoss.
"""
return {
"logits": NeuralType(('B', 'D'), LogitsType()),
"labels": NeuralType(('B',), LabelsType()),
}

@property
def output_types(self):
"""Output types definitions for AngularLoss.
loss:
NeuralType(None)
"""
return {"loss": NeuralType(elements_type=LossType())}

def __init__(self, scale=20.0, margin=1.35):
super().__init__()

self.eps = 1e-7
self.scale = scale
self.margin = margin

@typecheck()
def forward(self, logits, labels):
numerator = self.scale * torch.cos(
torch.acos(torch.clamp(torch.diagonal(logits.transpose(0, 1)[labels]), -1.0 + self.eps, 1 - self.eps))
+ self.margin
)
excl = torch.cat(
[torch.cat((logits[i, :y], logits[i, y + 1 :])).unsqueeze(0) for i, y in enumerate(labels)], dim=0
)
denominator = torch.exp(numerator) + torch.sum(torch.exp(self.scale * excl), dim=1)
L = numerator - torch.log(denominator)
return -torch.mean(L)
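For reference, forward() above implements the additive angular margin (ArcFace) softmax. With scale s, margin m, and theta_{y_i} the angle whose cosine is the target logit of sample i, the minibatch loss it computes is:

```latex
% ArcFace / additive angular margin softmax loss, matching forward() above:
% the target cosine has the margin added to its angle before rescaling by s.
L = -\frac{1}{N}\sum_{i=1}^{N}
    \log\frac{e^{\,s\cos(\theta_{y_i}+m)}}
             {e^{\,s\cos(\theta_{y_i}+m)} + \sum_{j\neq y_i} e^{\,s\cos\theta_j}}
```

Note that the logits are assumed to already be cosine similarities in [-1, 1], hence the clamp before acos.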
79 changes: 51 additions & 28 deletions nemo/collections/asr/models/label_models.py
@@ -17,20 +17,20 @@
import pickle as pkl
from typing import Dict, Optional, Union

import numpy as np
import torch
from omegaconf import DictConfig
from pytorch_lightning import Trainer

from nemo.collections.asr.data.audio_to_label import AudioToSpeechLabelDataSet
from nemo.collections.asr.losses.angularloss import AngularSoftmaxLoss
from nemo.collections.asr.parts.features import WaveformFeaturizer
from nemo.collections.asr.parts.perturb import process_augmentations
from nemo.collections.common.losses import CrossEntropyLoss as CELoss
from nemo.collections.common.metrics import TopKClassificationAccuracy, compute_topk_accuracy
from nemo.core.classes import ModelPT
from nemo.core.classes.common import typecheck
from nemo.core.neural_types import *
from nemo.utils import logging
from nemo.utils.decorators import experimental

__all__ = ['EncDecSpeakerLabelModel', 'ExtractSpeakerEmbeddingsModel']

@@ -50,7 +50,16 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
self.preprocessor = EncDecSpeakerLabelModel.from_config_dict(cfg.preprocessor)
self.encoder = EncDecSpeakerLabelModel.from_config_dict(cfg.encoder)
self.decoder = EncDecSpeakerLabelModel.from_config_dict(cfg.decoder)
self.loss = CELoss()
if 'angular' in cfg.decoder.params and cfg.decoder.params['angular']:
logging.info("Training with Angular Softmax Loss")
scale = cfg.loss.scale
margin = cfg.loss.margin
self.loss = AngularSoftmaxLoss(scale=scale, margin=margin)
else:
logging.info("Training with Softmax-CrossEntropy loss")
self.loss = CELoss()

self._accuracy = TopKClassificationAccuracy(top_k=[1])
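The branch above keys off cfg.decoder.params.angular together with a cfg.loss block. A minimal OmegaConf sketch of just those fields follows; all other model fields are omitted, the values mirror the AngularSoftmaxLoss defaults, and the authoritative schema is the example config in the repo.

```python
# Sketch of the config fields read by the loss selection above; the surrounding
# model config (preprocessor, encoder, train_ds, ...) is omitted here.
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "decoder": {"params": {"angular": True, "num_classes": 2}},  # num_classes is illustrative
    "loss": {"scale": 20.0, "margin": 1.35},                     # AngularSoftmaxLoss defaults
})

use_angular = "angular" in cfg.decoder.params and cfg.decoder.params["angular"]
print(use_angular, cfg.loss.scale, cfg.loss.margin)  # True 20.0 1.35
```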

def __setup_dataloader_from_config(self, config: Optional[Dict]):
if 'augmentor' in config:
@@ -139,38 +148,48 @@ def forward(self, input_signal, input_signal_length):
return logits, embs

# PTL-specific methods
def training_step(self, batch, batch_nb):
def training_step(self, batch, batch_idx):
audio_signal, audio_signal_len, labels, _ = batch
logits, _ = self.forward(input_signal=audio_signal, input_signal_length=audio_signal_len)
loss_value = self.loss(logits=logits, labels=labels)
labels_hat = torch.argmax(logits, dim=1)
n_correct_pred = torch.sum(labels == labels_hat, dim=0).item()
tensorboard_logs = {'train_loss': loss_value, 'training_batch_acc': (n_correct_pred / len(labels)) * 100}
self.loss_value = self.loss(logits=logits, labels=labels)

tensorboard_logs = {
'train_loss': self.loss_value,
'learning_rate': self._optimizer.param_groups[0]['lr'],
}

return {'loss': loss_value, 'log': tensorboard_logs, "n_correct_pred": n_correct_pred, "n_pred": len(labels)}
correct_counts, total_counts = self._accuracy(logits=logits, labels=labels)

def training_epoch_end(self, outputs):
train_acc = (sum([x['n_correct_pred'] for x in outputs]) / sum(x['n_pred'] for x in outputs)) * 100
tensorboard_logs = {'train_acc': train_acc}
for ki in range(correct_counts.shape[-1]):
correct_count = correct_counts[ki]
total_count = total_counts[ki]
top_k = self._accuracy.top_k[ki]
self.accuracy = (correct_count / float(total_count)) * 100

return {'train_acc': train_acc, 'log': tensorboard_logs}
tensorboard_logs['training_batch_accuracy_top@{}'.format(top_k)] = self.accuracy

def validation_step(self, batch, batch_idx):
return {'loss': self.loss_value, 'log': tensorboard_logs}

def validation_step(self, batch, batch_idx, dataloader_idx: int = 0):
audio_signal, audio_signal_len, labels, _ = batch
logits, _ = self.forward(input_signal=audio_signal, input_signal_length=audio_signal_len)
loss_value = self.loss(logits=logits, labels=labels)
labels_hat = torch.argmax(logits, dim=1)
n_correct_pred = torch.sum(labels == labels_hat, dim=0).item()
self.loss_value = self.loss(logits=logits, labels=labels)
correct_counts, total_counts = self._accuracy(logits=logits, labels=labels)
return {'val_loss': self.loss_value, 'val_correct_counts': correct_counts, 'val_total_counts': total_counts}

return {'val_loss': loss_value, "n_correct_pred": n_correct_pred, "n_pred": len(labels)}
def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0):
self.val_loss_mean = torch.stack([x['val_loss'] for x in outputs]).mean()
correct_counts = torch.stack([x['val_correct_counts'] for x in outputs])
total_counts = torch.stack([x['val_total_counts'] for x in outputs])

def validation_epoch_end(self, outputs):
val_loss_mean = torch.stack([x['val_loss'] for x in outputs]).mean()
val_acc = (sum([x['n_correct_pred'] for x in outputs]) / sum(x['n_pred'] for x in outputs)) * 100
logging.info("validation accuracy {:.3f}".format(val_acc))
tensorboard_logs = {'validation_loss': val_loss_mean, 'validation_acc': val_acc}
topk_scores = compute_topk_accuracy(correct_counts, total_counts)
logging.info("val_loss: {:.3f}".format(self.val_loss_mean))
tensorboard_log = {'val_loss': self.val_loss_mean}
for top_k, score in zip(self._accuracy.top_k, topk_scores):
tensorboard_log['val_epoch_top@{}'.format(top_k)] = score
self.accuracy = score * 100

return {'val_loss': val_loss_mean, 'log': tensorboard_logs}
return {'log': tensorboard_log}

def test_step(self, batch, batch_ix):
audio_signal, audio_signal_len, labels, _ = batch
@@ -188,16 +207,17 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
super().__init__(cfg=cfg, trainer=trainer)

def test_step(self, batch, batch_ix):
audio_signal, audio_signal_len, labels, _ = batch
audio_signal, audio_signal_len, labels, slices = batch
_, embs = self.forward(input_signal=audio_signal, input_signal_length=audio_signal_len)
return {'embs': embs, 'labels': labels}
return {'embs': embs, 'labels': labels, 'slices': slices}

def test_epoch_end(self, outputs):
embs = torch.cat([x['embs'] for x in outputs])
slices = torch.cat([x['slices'] for x in outputs])
emb_shape = embs.shape[-1]
embs = embs.view(-1, emb_shape).cpu().numpy()
out_embeddings = {}

start_idx = 0
with open(self.test_manifest, 'r') as manifest:
for idx, line in enumerate(manifest.readlines()):
line = line.strip()
@@ -206,7 +226,10 @@ def test_epoch_end(self, outputs):
uniq_name = '@'.join(structure)
if uniq_name in out_embeddings:
raise KeyError("Embeddings for label {} already present in emb dictionary".format(uniq_name))
out_embeddings[uniq_name] = embs[idx]
num_slices = slices[idx]
end_idx = start_idx + num_slices
out_embeddings[uniq_name] = embs[start_idx:end_idx].mean(axis=0)
start_idx = end_idx

embedding_dir = os.path.join(self.embedding_dir, 'embeddings')
if not os.path.exists(embedding_dir):
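The test_epoch_end change above groups consecutive slice embeddings back into one averaged embedding per utterance, using the slice counts returned by the collate function. A standalone sketch of that grouping step (hypothetical arrays, not the model class):

```python
# Standalone sketch of the per-utterance averaging in test_epoch_end:
# consecutive rows of `embs` belong to the same utterance, and `slices[i]`
# says how many rows utterance i contributed.
import numpy as np


def average_sliced_embeddings(embs: np.ndarray, slices) -> list:
    averaged = []
    start_idx = 0
    for num_slices in slices:
        end_idx = start_idx + int(num_slices)
        averaged.append(embs[start_idx:end_idx].mean(axis=0))
        start_idx = end_idx
    return averaged


embs = np.arange(12, dtype=np.float32).reshape(6, 2)  # 6 slice embeddings of dim 2
print(average_sliced_embeddings(embs, [1, 3, 2]))     # 3 utterance-level embeddings
```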