Skip to content

Commit

Permalink
Merge pull request #11 from microsoft/jingywa/hfbert-changes
Browse files Browse the repository at this point in the history
Bert type cast fix
  • Loading branch information
raviskolli authored May 11, 2021
2 parents ae1411f + 25e7be2 commit 6b9500b
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/transformers/models/bert/modeling_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -1266,6 +1266,7 @@ def __init__(self, config):

self.bert = BertModel(config, add_pooling_layer=False)
self.cls = BertOnlyMLMHead(config)
self.ort = config.ort

self.init_weights()

Expand Down Expand Up @@ -1326,7 +1327,10 @@ def forward(
masked_lm_loss = None
if labels is not None:
loss_fct = CrossEntropyLoss() # -100 index = padding token
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
if self.ort:
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size).to(torch.float32), labels.view(-1))
else:
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

if not return_dict:
output = (prediction_scores,) + outputs[2:]
Expand Down

0 comments on commit 6b9500b

Please sign in to comment.