diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py
index 53e920e3650a..9311c45e6514 100755
--- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py
+++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py
@@ -769,6 +769,7 @@ class PegasusXPreTrainedModel(PreTrainedModel):
     config_class = PegasusXConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
+    _no_split_modules = [r"PegasusXEncoderLayer", r"PegasusXDecoderLayer"]
 
     def _init_weights(self, module):
         std = self.config.init_std
@@ -1299,6 +1300,8 @@ def forward(
         # embed positions
         positions = self.embed_positions(inputs_embeds, past_key_values_length)
 
+        positions = positions.to(inputs_embeds.device)
+
         hidden_states = inputs_embeds + positions
 
         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
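
A minimal usage sketch of what this diff enables: declaring `_no_split_modules` lets Accelerate's `device_map="auto"` shard the model across devices without splitting an encoder or decoder layer, and moving `positions` to `inputs_embeds.device` keeps the position embeddings on the same device as the rest of that layer. The checkpoint name and dtype below are illustrative, not part of the diff, and `device_map="auto"` assumes `accelerate` is installed.

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("google/pegasus-x-base")

# With _no_split_modules defined, each PegasusXEncoderLayer/PegasusXDecoderLayer
# is placed on a single device when the weights are dispatched.
model = AutoModelForSeq2SeqLM.from_pretrained(
    "google/pegasus-x-base",
    torch_dtype=torch.float16,   # illustrative; pick a dtype that fits your hardware
    device_map="auto",           # requires `accelerate`
)

# Inputs go to the device holding the first model parameters; the .to() added in
# the decoder forward keeps positions aligned with inputs_embeds on sharded setups.
inputs = tokenizer("Summarize: a long document ...", return_tensors="pt").to(model.device)
summary_ids = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True))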