diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index 58938b99eeffe5..e8303a79848959 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -301,14 +301,15 @@ def _merge_input_ids_with_image_features( pad_mask = input_ids == self.pad_token_id # expand masks to match embedding dimension - text_mask_expanded = text_mask.unsqueeze(-1).expand(-1, -1, embed_dim) - pad_mask_expanded = pad_mask.unsqueeze(-1).expand(-1, -1, embed_dim) + text_mask_expanded = text_mask.unsqueeze(-1).expand(-1, -1, embed_dim).to(inputs_embeds.device) + pad_mask_expanded = pad_mask.unsqueeze(-1).expand(-1, -1, embed_dim).to(inputs_embeds.device) # insert padding and text token embeddings final_embedding = torch.where(text_mask_expanded, inputs_embeds, final_embedding) final_embedding = torch.where(pad_mask_expanded, torch.zeros_like(final_embedding), final_embedding) # insert image embeddings - the image mask is always less or equal to the sentence in length final_embedding = final_embedding.masked_scatter( - image_mask.unsqueeze(-1).expand_as(final_embedding), scaled_image_features + image_mask.unsqueeze(-1).expand_as(final_embedding).to(device=final_embedding.device), + scaled_image_features.to(device=final_embedding.device, dtype=final_embedding.dtype), ) final_embedding = torch.where(pad_mask_expanded, torch.zeros_like(final_embedding), final_embedding) if attention_mask is not None: @@ -329,10 +330,12 @@ def _merge_input_ids_with_image_features( if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) # unmask the prefill causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - token_type_ids[:, None, None, :] == 0, 0 + token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0 ) padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( @@ -484,7 +487,7 @@ def forward( # we use the input attention mask to shift the logits and labels, because it is 2D. shift_attention_mask = input_attention_mask[..., 1:] shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous() - shift_labels = shift_labels[shift_attention_mask.to(logits.device) != 0].contiguous() + shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous() else: shift_logits = shift_logits.contiguous() shift_labels = shift_labels.contiguous()