diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 3bff8eea50f0c6..9135bb204846ee 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -476,6 +476,7 @@ def _prepare_attention_mask_for_generation(
         )
         can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
         attention_mask_from_padding = inputs.ne(pad_token_id).long()
+
         attention_mask = (
             attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
         )
@@ -1340,7 +1341,10 @@ def _get_static_cache(self, max_batch_size: int, max_cache_len: int) -> StaticCa
         return self._static_cache
 
     def _prepare_special_tokens(
-        self, generation_config: GenerationConfig, kwargs_has_attention_mask: Optional[bool] = None
+        self,
+        generation_config: GenerationConfig,
+        kwargs_has_attention_mask: Optional[bool] = None,
+        device: Optional[Union[torch.device, str]] = None,
     ):
         """
         Prepares the special tokens for generation, overwriting the generation config with their processed versions
@@ -1352,15 +1356,18 @@ def _prepare_special_tokens(
         """
 
         # Convert special tokens to tensors (if they exist)
-        def _tensor_or_none(token):
+        def _tensor_or_none(token, device=None):
+            if device is None:
+                device = self.device
+
             if token is None or isinstance(token, torch.Tensor):
                 return token
-            return torch.tensor(token, device=self.device, dtype=torch.long)
+            return torch.tensor(token, device=device, dtype=torch.long)
 
-        bos_token_id = _tensor_or_none(generation_config.bos_token_id)
-        eos_token_id = _tensor_or_none(generation_config.eos_token_id)
-        pad_token_id = _tensor_or_none(generation_config.pad_token_id)
-        decoder_start_token_id = _tensor_or_none(generation_config.decoder_start_token_id)
+        bos_token_id = _tensor_or_none(generation_config.bos_token_id, device=device)
+        eos_token_id = _tensor_or_none(generation_config.eos_token_id, device=device)
+        pad_token_id = _tensor_or_none(generation_config.pad_token_id, device=device)
+        decoder_start_token_id = _tensor_or_none(generation_config.decoder_start_token_id, device=device)
         decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id
 
         # We can have more than one eos token. Always treat it as a 1D tensor (when it exists).
@@ -1511,7 +1518,6 @@ def generate(
         accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
         requires_attention_mask = "encoder_outputs" not in model_kwargs
         kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
-        self._prepare_special_tokens(generation_config, kwargs_has_attention_mask)
 
         # 3. Define model inputs
         inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
@@ -1519,6 +1525,9 @@ def generate(
         )
         batch_size = inputs_tensor.shape[0]
 
+        device = inputs_tensor.device
+        self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
+
         # decoder-only models must use left-padding for batched generation.
         if not self.config.is_encoder_decoder and not is_torchdynamo_compiling():
             # If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
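
Below is a minimal, self-contained sketch (not the actual Transformers method) of the device-fallback pattern the patch introduces in `_tensor_or_none`: an explicit `device` argument wins, otherwise the code falls back to a default. The real method falls back to `self.device`; the `fallback_device` parameter and the token ids here are stand-ins so the snippet runs without a model. With the patch, `generate()` passes `inputs_tensor.device`, so the special-token tensors land on the same device as the prompt ids.

from typing import Optional, Union

import torch


def _tensor_or_none(token, device: Optional[Union[torch.device, str]] = None, fallback_device: str = "cpu"):
    # Prefer the caller-supplied device; otherwise fall back to a default.
    # (The real method uses `self.device`; `fallback_device` is a hypothetical
    # stand-in so this sketch has no model dependency.)
    if device is None:
        device = fallback_device

    # None and pre-built tensors pass through unchanged, mirroring the patch.
    if token is None or isinstance(token, torch.Tensor):
        return token
    return torch.tensor(token, device=device, dtype=torch.long)


# Usage mirroring the new call site in `generate()`: take the device from the
# input ids so the special tokens live alongside the prompt (arbitrary ids here).
input_ids = torch.tensor([[101, 2023, 102]])
eos_token_id = _tensor_or_none(2, device=input_ids.device)
assert eos_token_id.device == input_ids.device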