diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py
index 34dd5329bebab2..b72e15f409d84a 100755
--- a/src/transformers/models/bert/modeling_bert.py
+++ b/src/transformers/models/bert/modeling_bert.py
@@ -1266,6 +1266,7 @@ def __init__(self, config):
 
         self.bert = BertModel(config, add_pooling_layer=False)
         self.cls = BertOnlyMLMHead(config)
+        self.ort = config.ort
 
         self.init_weights()
 
@@ -1326,7 +1327,10 @@ def forward(
         masked_lm_loss = None
         if labels is not None:
             loss_fct = CrossEntropyLoss()  # -100 index = padding token
-            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+            if self.ort:
+                masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size).to(torch.float32), labels.view(-1))
+            else:
+                masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
 
         if not return_dict:
             output = (prediction_scores,) + outputs[2:]
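
The change itself is small: when `config.ort` is set, the MLM logits are cast to float32 before the cross-entropy loss is computed. Below is a minimal standalone sketch of that pattern, assuming the motivation is numerical stability of log-softmax on fp16 logits under ONNX Runtime (ORT) mixed-precision training; the helper name `mlm_loss` and its signature are illustrative and not part of the patch.

```python
import torch
from torch.nn import CrossEntropyLoss

# Illustrative sketch of the loss path patched above (assumption: under ORT
# mixed-precision training the logits arrive as float16, and upcasting to
# float32 before the loss keeps the log-softmax numerically stable).
def mlm_loss(prediction_scores: torch.Tensor, labels: torch.Tensor,
             vocab_size: int, ort: bool = False) -> torch.Tensor:
    loss_fct = CrossEntropyLoss()  # -100 index = padding token
    logits = prediction_scores.view(-1, vocab_size)
    if ort:
        # Mirror of the patched branch: cast fp16 logits up before the loss.
        logits = logits.to(torch.float32)
    return loss_fct(logits, labels.view(-1))

# Usage with toy shapes: (batch=2, seq=8, vocab=32) half-precision logits.
scores = torch.randn(2, 8, 32, dtype=torch.float16)
labels = torch.randint(0, 32, (2, 8))
print(mlm_loss(scores, labels, vocab_size=32, ort=True))
```

Casting only the loss input (rather than the whole head) keeps the fp16 forward pass intact and adds one upcast per step, which is a common trade-off when a runtime does not insert the cast automatically.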