Angular loss1.0 #1101
Conversation
nithinraok commented Sep 1, 2020 (edited)
- Added angular loss with cosine angle for 1.0
- Fixed multi-GPU metric issue by reusing the classification top-k accuracy metric
- Added support for embedding extraction for speaker diarization
Force-pushed from 48c3048 to 2c19ad3
Mostly LGTM
This pull request fixes 2 alerts when merging c6529f6 into 2ab5b64 - view on LGTM.com
Some minor changes pertinent to the model itself, and some major concerns regarding logging callbacks.
slice_length = self.featurizer.sample_rate * self.time_length
_, audio_lengths, _, tokens_lengths = zip(*batch)
slice_length = min(slice_length, max(audio_lengths))
shift = 1 * 16000
Hardcoded sample_rate? Replace with featurizer.sample_rate
Thanks, missed it. Done.
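For illustration, a minimal standalone sketch of the suggested fix, deriving the one-second shift from the configured sample rate instead of the hardcoded 16000. The sample_rate and time_length values below are placeholders for self.featurizer.sample_rate and self.time_length:

```python
sample_rate = 16000  # placeholder for self.featurizer.sample_rate
time_length = 8      # placeholder for self.time_length, in seconds

slice_length = sample_rate * time_length
# Derive the one-second shift from the configured sample rate rather than
# hardcoding 16000, so non-16 kHz audio keeps working.
shift = 1 * sample_rate
print(slice_length, shift)
```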
""" | ||
return {"loss": NeuralType(elements_type=LossType())} | ||
|
||
def __init__(self, s=20.0, m=1.35): |
No option to override epsilon for other tasks? Add default eps=1e-7.
Also, don't use 1-character names for variables, and add a docstring to this class.
eps is not a parameter; it's there to avoid division by zero. Yes, I'll add a docstring. s and m are very well-known short forms in the angular loss literature for scale and margin. If it is compulsory, I will change them.
super().__init__()

self.eps = 1e-7
self.s = s
Again, don't save variables with 1-character names. If it's from a paper, add a reference section and explain what this variable is supposed to do (better yet, just use a descriptive name).
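For context, here is a minimal sketch of an additive angular margin softmax loss written with the descriptive names suggested above (scale, margin, eps). The class name and exact formulation are illustrative assumptions, not necessarily this PR's implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AngularSoftmaxLoss(nn.Module):
    """Illustrative additive angular margin softmax loss.

    scale: multiplier applied to the cosine logits ("s" in the literature).
    margin: angle added to the target class ("m" in the literature).
    eps: keeps acos() inside its valid domain; not a tunable hyperparameter.
    """

    def __init__(self, scale: float = 20.0, margin: float = 1.35, eps: float = 1e-7):
        super().__init__()
        self.scale = scale
        self.margin = margin
        self.eps = eps
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, cos_logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # cos_logits are cosine similarities between normalized embeddings and
        # normalized class weights, so they lie in [-1, 1].
        cos_logits = torch.clamp(cos_logits, -1.0 + self.eps, 1.0 - self.eps)
        theta = torch.acos(cos_logits)
        # Add the angular margin only to the true-class angle.
        one_hot = F.one_hot(labels, num_classes=cos_logits.size(1)).bool()
        theta = torch.where(one_hot, theta + self.margin, theta)
        return self.criterion(self.scale * torch.cos(theta), labels)
```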
self.loss = CELoss()
if 'angular' in cfg.decoder.params and cfg.decoder.params['angular']:
    logging.info("Training with Angular Softmax Loss")
    s = cfg.loss.s
Config needs to have descriptive names, not one-character variable names.
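As an illustration of this suggestion, a hypothetical loss config fragment with descriptive keys in place of cfg.loss.s / cfg.loss.m (the key names are assumptions; the values mirror the defaults in this PR):

```python
from omegaconf import OmegaConf

# Hypothetical descriptive config keys replacing the one-character s / m.
loss_cfg = OmegaConf.create({"scale": 20.0, "margin": 1.35})
print(loss_cfg.scale, loss_cfg.margin)
```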
self,
feat_in,
num_classes,
emb_sizes=[1024, 1024],
Don't directly use lists here; use None, check for None below, and create [1024, 1024] if None. Refer to https://docs.python-guide.org/writing/gotchas/#mutable-default-arguments
Done
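A minimal sketch of the None-default pattern suggested above; the class name is a stand-in for the decoder under review, not the PR's actual class:

```python
class SpeakerDecoderStub:
    """Stand-in class showing the None-default pattern for emb_sizes."""

    def __init__(self, feat_in, num_classes, emb_sizes=None):
        # A mutable default like emb_sizes=[1024, 1024] is created once and
        # shared across every instance; defaulting to None and building the
        # list inside __init__ avoids that gotcha.
        if emb_sizes is None:
            emb_sizes = [1024, 1024]
        self.feat_in = feat_in
        self.num_classes = num_classes
        self.emb_sizes = emb_sizes
```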
embs.append(emb)

if self.angular:
    for W in self.final.parameters():
        _ = F.normalize(W, p=2, dim=1)
https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.normalize
F.normalize is not an in-place op unless you use out=, so what's the point of this loop then?
It just normalizes the weights before calculating the loss. I missed the W = here, thanks Som.
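A small standalone illustration of the point raised here (not from the PR itself): F.normalize is out-of-place, so its result has to be assigned back to take effect.

```python
import torch
import torch.nn.functional as F

W = torch.randn(10, 256)

# Out-of-place: W is left unchanged and the normalized tensor is discarded.
_ = F.normalize(W, p=2, dim=1)

# Assigning the result back is what actually applies the normalization.
W = F.normalize(W, p=2, dim=1)
assert torch.allclose(W.norm(p=2, dim=1), torch.ones(10))
```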
batch_idx + 1,
total_batches,
pl_module.loss_value,
pl_module.accuracy,
Not all modules have accuracy, so this callback will fail for a lot of models. Why not just read the log in its entirety and print all of the values in it? Can't we access the log at the end of train_batch_end?
Unfortunately not yet; PTL is working on it.
@rank_zero_only
def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
    print_freq = trainer.row_log_interval
This will print every single batch (since the PTL default is 10, the NeMo default is 1, not 1.0).
Yes, it's provided by the user based on what percentage of num_batches or the exact number they require.
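For reference, a rough sketch of a rank-zero printing callback gated on the log interval, using getattr so models without an accuracy attribute don't crash. The attribute names follow the Lightning API quoted above; the callback class itself is an illustrative assumption, not the PR's code:

```python
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities import rank_zero_only


class BatchPrintCallback(Callback):
    """Illustrative callback printing training progress every print_freq batches."""

    @rank_zero_only
    def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        print_freq = trainer.row_log_interval  # user-configured logging interval
        if (batch_idx + 1) % print_freq == 0:
            total_batches = trainer.num_training_batches
            loss = getattr(pl_module, "loss_value", "n/a")
            accuracy = getattr(pl_module, "accuracy", "n/a")  # not all models define this
            print(f"batch {batch_idx + 1}/{total_batches} loss={loss} accuracy={accuracy}")
```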
)

def on_validation_epoch_end(self, trainer, pl_module):
    logging.info(
Same, not all models have accuracy, so this will crash for them. Is there a way to access the log dictionary itself?
Force-pushed from c6529f6 to fdd898d
This pull request introduces 2 alerts and fixes 3 when merging fdd898d into 292e2fb - view on LGTM.com
Minor comments
@@ -49,12 +39,15 @@
def main(cfg):

    logging.info(f'Hydra config: {cfg.pretty()}')
    trainer = pl.Trainer(logger=False, checkpoint_callback=False)
    if cfg.trainer.gpus > 1:
Wait, do this only during inference (trainer.test()); otherwise you can't use multi-GPU training.
spkr_get_emb.py is only run for inference purposes.
Ah ok.
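In other words, the quoted hunk sets up a bare single-process Trainer because spkr_get_emb.py only runs inference; a minimal sketch of that pattern (the model loading is assumed to happen elsewhere in the script):

```python
import pytorch_lightning as pl

# Inference-only: regardless of how many GPUs the training config requested,
# embedding extraction runs trainer.test() on a single device.
trainer = pl.Trainer(logger=False, checkpoint_callback=False)
# trainer.test(model)  # model loading omitted; see spkr_get_emb.py
```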
Looks good to go. I'll let @fayejf look it over for comments and then let's merge
This pull request introduces 2 alerts and fixes 3 when merging 8e2cd41 into b5ecf8f - view on LGTM.com
This pull request introduces 2 alerts and fixes 3 when merging 8007677 into e9d98c6 - view on LGTM.com
Looks good to me! Just have two minor questions.
Hey @nithinraok, I want to perform speaker diarization by providing an audio file and getting a multi-speaker transcript (STT of identified speakers). How can I do that with this PR?
You could extract embeddings with this script and use those frame-level embeddings to perform spectral clustering, with num_speakers as the number of clusters. We don't have a unified ready-to-go script for this yet, but all the individual pieces are already there. I will add one in the coming weeks with more features.
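As a rough illustration of that recipe (not part of this PR, and using scikit-learn purely as an example clustering backend): given frame-level embeddings, spectral clustering with a known speaker count could look like this.

```python
import numpy as np
from sklearn.cluster import SpectralClustering

# Frame-level embeddings, e.g. extracted with the PR's embedding script;
# random data is used here as a stand-in. Shape: (num_frames, emb_dim).
embeddings = np.random.randn(200, 512).astype(np.float32)
num_speakers = 2  # known number of speakers in the recording

# Cosine affinity between L2-normalized embeddings.
normed = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
affinity = np.maximum(normed @ normed.T, 0.0)  # spectral clustering expects non-negative similarities

labels = SpectralClustering(
    n_clusters=num_speakers,
    affinity="precomputed",
    random_state=0,
).fit_predict(affinity)
print(labels[:20])  # per-frame speaker assignments
```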
Thanks, have been waiting for this for a long time. @nithinraok
docs: missing space