From 6b018da4b80f55f2f955495e33fa7bb2ba1d7c5e Mon Sep 17 00:00:00 2001
From: Shanmugam Ramasamy <111910568+shanmugamr1992@users.noreply.github.com>
Date: Wed, 16 Nov 2022 14:14:52 -0800
Subject: [PATCH] Fixing bug in Megatron BERT when loss mask is all zeros
 (#5424)

* Fixing bug when loss mask is fully zero

Signed-off-by: Shanmugam Ramasamy <111910568+shanmugamr1992@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update megatron_bert_model.py

Signed-off-by: Shanmugam Ramasamy <111910568+shanmugamr1992@users.noreply.github.com>

* Update dataset_utils.py

Signed-off-by: Shanmugam Ramasamy <111910568+shanmugamr1992@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update dataset_utils.py

Signed-off-by: Shanmugam Ramasamy <111910568+shanmugamr1992@users.noreply.github.com>

* Update dataset_utils.py

Signed-off-by: Shanmugam Ramasamy <111910568+shanmugamr1992@users.noreply.github.com>

Signed-off-by: Shanmugam Ramasamy <111910568+shanmugamr1992@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Sandeep Subramanian
---
 .../nlp/data/language_modeling/megatron/dataset_utils.py | 4 ++++
 .../nlp/models/language_modeling/megatron_bert_model.py  | 8 +++++++-
 2 files changed, 11 insertions(+), 1 deletion(-)

diff --git a/nemo/collections/nlp/data/language_modeling/megatron/dataset_utils.py b/nemo/collections/nlp/data/language_modeling/megatron/dataset_utils.py
index bd071cf3f05e..75cea0bca417 100644
--- a/nemo/collections/nlp/data/language_modeling/megatron/dataset_utils.py
+++ b/nemo/collections/nlp/data/language_modeling/megatron/dataset_utils.py
@@ -234,6 +234,10 @@ def create_masked_lm_predictions(
         return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary)
 
     num_to_predict = min(max_predictions_per_seq, max(1, int(round(len(tokens) * masked_lm_prob))))
+    if num_to_predict < 1:
+        logging.warning(
+            f'Number of tokens is {len(tokens)} and mask probability is {masked_lm_prob}. None of the tokens will be masked.'
+        )
 
     ngrams = np.arange(1, max_ngram_size + 1, dtype=np.int64)
     if not geometric_dist:
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py b/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py
index 0a850289301f..20da2a38f7ce 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py
@@ -358,7 +358,13 @@ def loss_func(self, loss_mask, sentence_order, output_tensor):
 
         lm_loss_ = lm_loss_.float()
         loss_mask = loss_mask.float()
-        lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
+
+        # Sometimes, when the number of tokens is very small, none of the tokens get masked for prediction and the loss mask is all zeros,
+        # i.e. the entire batch is masked out (in practice, when MBS is 1 or 2 and each sample has fewer than 7 tokens).
+        if loss_mask.sum() == 0:
+            lm_loss = torch.sum(lm_loss_.view(-1)) * 0.0
+        else:
+            lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
 
         if sop_logits is not None:
             sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(), sentence_order.view(-1), ignore_index=-1)
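
For reference, the sketch below is a minimal standalone illustration (not part of the patch) of why the guarded branch matters: with an all-zero loss mask, the original reduction divides by zero and yields NaN, while the patched branch returns a zero loss that still participates in the autograd graph. The tensors lm_loss_ and loss_mask here are hypothetical stand-ins, not the ones computed inside MegatronBertModel.

import torch

# Hypothetical stand-in tensors (not from the patch): per-token LM losses and an
# all-zero loss mask, mimicking a batch where no tokens were selected for masking.
lm_loss_ = torch.tensor([[0.7, 1.2, 0.4]])
loss_mask = torch.zeros_like(lm_loss_)

# Unguarded reduction: loss_mask.sum() is 0, so 0 / 0 produces NaN, which would
# propagate into the gradients and corrupt subsequent training steps.
naive_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
print(naive_loss)  # tensor(nan)

# Guarded reduction, mirroring the patch: scale by 0.0 instead of dividing by zero,
# so the loss is exactly 0 while remaining a valid node in the autograd graph.
if loss_mask.sum() == 0:
    lm_loss = torch.sum(lm_loss_.view(-1)) * 0.0
else:
    lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
print(lm_loss)  # tensor(0.)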