Fixing bug in Megatron BERT when loss mask is all zeros (#5424)
* Fixing bug when loss mask is fully zero

Signed-off-by: Shanmugam Ramasamy <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update megatron_bert_model.py

Signed-off-by: Shanmugam Ramasamy <[email protected]>

* Update dataset_utils.py

Signed-off-by: Shanmugam Ramasamy <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update dataset_utils.py

Signed-off-by: Shanmugam Ramasamy <[email protected]>

* Update dataset_utils.py

Signed-off-by: Shanmugam Ramasamy <[email protected]>

Signed-off-by: Shanmugam Ramasamy <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Sandeep Subramanian <[email protected]>
3 people authored and tango4j committed Nov 17, 2022
1 parent 785426c commit 6b018da
Showing 2 changed files with 11 additions and 1 deletion.
dataset_utils.py
@@ -234,6 +234,10 @@ def create_masked_lm_predictions(
         return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary)
 
     num_to_predict = min(max_predictions_per_seq, max(1, int(round(len(tokens) * masked_lm_prob))))
+    if num_to_predict < 1:
+        logging.warning(
+            F'Number of tokens is : {len(tokens)} and mask_probability is {masked_lm_prob}. None of the tokens will be masked'
+        )
 
     ngrams = np.arange(1, max_ngram_size + 1, dtype=np.int64)
     if not geometric_dist:
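
As a side note, here is a small illustrative calculation of how num_to_predict can end up at zero. All values are hypothetical and not taken from NeMo's configs: the max(1, ...) term never drops below one, so the new warning can only fire when max_predictions_per_seq itself is zero, for example if it was derived from a very short sequence and truncated.

# Hypothetical values; only the min/max/round structure matches the code above.
masked_lm_prob = 0.15
tokens = list(range(5))                                        # a 5-token sample
max_predictions_per_seq = int(len(tokens) * masked_lm_prob)    # int(0.75) == 0, assumed upstream truncation

num_to_predict = min(max_predictions_per_seq, max(1, int(round(len(tokens) * masked_lm_prob))))
print(num_to_predict)  # 0 -> no position is selected for masking, so the MLM loss mask stays all zeros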
megatron_bert_model.py
@@ -358,7 +358,13 @@ def loss_func(self, loss_mask, sentence_order, output_tensor):
 
         lm_loss_ = lm_loss_.float()
         loss_mask = loss_mask.float()
-        lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
+
+        # Sometimes when the number of tokens is very small, none of the tokens get masked for prediction. In that case loss mask is all zeros
+        # i.e Happens when the entire batch is masked out (Practically when MBS=1 or 2, and the number of tokens in each batch is < 7 )
+        if loss_mask.sum() == 0:
+            lm_loss = torch.sum(lm_loss_.view(-1)) * 0.0
+        else:
+            lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
 
         if sop_logits is not None:
             sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(), sentence_order.view(-1), ignore_index=-1)
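
To make the change above concrete, here is a minimal standalone sketch with dummy tensors (shapes chosen arbitrarily, not NeMo code). It shows why the guard matters: with an all-zero mask the old expression divides zero by zero and produces NaN, while the new branch returns a zero loss that is still attached to the autograd graph, so the backward pass yields zero gradients instead of NaNs.

import torch

lm_loss_ = torch.randn(4, 8, requires_grad=True)  # dummy per-token losses
loss_mask = torch.zeros(4, 8)                      # no token was masked in this batch

# Old expression: 0 / 0 -> NaN, which poisons the backward pass.
nan_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

# Guarded version from the commit: a zero loss that is still connected to the graph.
if loss_mask.sum() == 0:
    lm_loss = torch.sum(lm_loss_.view(-1)) * 0.0
else:
    lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

lm_loss.backward()        # gradients on lm_loss_ are all zeros instead of NaN
print(nan_loss, lm_loss)  # tensor(nan, ...) vs. tensor(0., grad_fn=...)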
