Tests: move generate tests to the right mixin and delete redundant tests (huggingface#34464)

* tmp commit

* tmp commit

* cull overwrites of deleted tests

* typo

* more specific docstring

* make fixup

* parameterize at the top?

* correction

* more deletions :D

* tmp commit

* for VLMs too

* fix _check_outputs

* test nit

* make fixup

* fix another flaky

* test_generate_from_inputs_embeds -- handle missing attention mask
gante authored and BernardZach committed Dec 5, 2024
1 parent 9f68b9d commit f60388c
Showing 46 changed files with 263 additions and 2,346 deletions.
33 changes: 22 additions & 11 deletions src/transformers/generation/utils.py
@@ -378,10 +378,14 @@ def prepare_inputs_for_generation(
         # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
         # Exception 1: when passing input_embeds, input_ids may be missing entries
         # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
-        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case
+        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
+        # (we can't check exception 3 while compiling)
         if past_key_values is not None:
             model_inputs["past_key_values"] = past_key_values
-            if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]:  # Exception 1 or Exception 3
+            if (
+                inputs_embeds is not None  # Exception 1
+                or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1])  # Exception 3
+            ):
                 input_ids = input_ids[:, -cache_position.shape[0] :]
             elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                 input_ids = input_ids[:, cache_position]
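
For context, a minimal standalone sketch (plain torch, not part of the commit) of the slicing rule this hunk adjusts: with a KV cache, only the positions listed in `cache_position` still need to be fed to the model, and under compilation the data-dependent bounds check on `cache_position[-1]` is skipped in favor of always taking the tail slice.

# Minimal sketch, assuming plain torch tensors (not the library's own code).
import torch

full_input_ids = torch.tensor([[101, 42, 7, 9, 55]])  # prompt + tokens generated so far
cache_position = torch.tensor([4])                     # only position 4 has no cache entry yet

# Exception 1 / Exception 3 path: keep the last `cache_position.shape[0]` tokens.
tail = full_input_ids[:, -cache_position.shape[0]:]
print(tail)                                  # tensor([[55]])

# Default path: index by `cache_position` directly (same result here).
print(full_input_ids[:, cache_position])     # tensor([[55]])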
@@ -414,7 +418,7 @@ def prepare_inputs_for_generation(
         for model_input_name in ["position_ids", "token_type_ids"]:
             model_input = kwargs.get(model_input_name)
             if model_input is not None:
-                if past_key_values:
+                if past_key_values is not None:
                     model_input = model_input[:, -input_ids.shape[1] :]
                     model_input = model_input.clone(memory_format=torch.contiguous_format)
                 model_inputs[model_input_name] = model_input
@@ -568,27 +572,34 @@ def _maybe_initialize_input_ids_for_generation(
 
     def _prepare_attention_mask_for_generation(
         self,
-        inputs: torch.Tensor,
-        pad_token_id: Optional[torch.Tensor],
-        eos_token_id: Optional[torch.Tensor],
+        inputs_tensor: torch.Tensor,
+        generation_config: GenerationConfig,
+        model_kwargs: Dict[str, Any],
     ) -> torch.LongTensor:
+        pad_token_id = generation_config._pad_token_tensor
+        eos_token_id = generation_config._eos_token_tensor
+
+        # `input_ids` may be present in the model kwargs, instead of being the main input (e.g. multimodal model)
+        if "input_ids" in model_kwargs and model_kwargs["input_ids"].shape[1] > 0:
+            inputs_tensor = model_kwargs["input_ids"]
+
         # No information for attention mask inference -> return default attention mask
-        default_attention_mask = torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)
+        default_attention_mask = torch.ones(inputs_tensor.shape[:2], dtype=torch.long, device=inputs_tensor.device)
         if pad_token_id is None:
             return default_attention_mask
 
-        is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long]
+        is_input_ids = len(inputs_tensor.shape) == 2 and inputs_tensor.dtype in [torch.int, torch.long]
         if not is_input_ids:
             return default_attention_mask
 
         is_pad_token_in_inputs = (pad_token_id is not None) and (
-            isin_mps_friendly(elements=inputs, test_elements=pad_token_id).any()
+            isin_mps_friendly(elements=inputs_tensor, test_elements=pad_token_id).any()
         )
         is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~(
             isin_mps_friendly(elements=eos_token_id, test_elements=pad_token_id).any()
         )
         can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
-        attention_mask_from_padding = inputs.ne(pad_token_id).long()
+        attention_mask_from_padding = inputs_tensor.ne(pad_token_id).long()
 
         attention_mask = (
             attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
@@ -2020,7 +2031,7 @@ def generate(
 
         if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
             model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
-                inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor
+                inputs_tensor, generation_config, model_kwargs
             )
         elif kwargs_has_attention_mask:
             # TODO (joao): generalize this check with other types of inputs
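
As a reading aid (not part of the commit), a minimal sketch of the inference rule `_prepare_attention_mask_for_generation` applies after pulling `pad_token_id`/`eos_token_id` from `generation_config` and, for multimodal models, `input_ids` from `model_kwargs`: a 0/1 mask is built from padding only when a pad token appears in the inputs and differs from the eos token; otherwise an all-ones mask is returned. `infer_attention_mask` below is a hypothetical helper, not the library function.

# Minimal standalone sketch of the mask-inference rule shown above (simplified assumption).
import torch

def infer_attention_mask(input_ids: torch.Tensor, pad_token_id: int, eos_token_id: int) -> torch.Tensor:
    default_mask = torch.ones_like(input_ids, dtype=torch.long)
    has_pad = bool((input_ids == pad_token_id).any())
    pad_differs_from_eos = pad_token_id != eos_token_id
    if has_pad and pad_differs_from_eos:
        return (input_ids != pad_token_id).long()  # 0 at padding, 1 at real tokens
    return default_mask

ids = torch.tensor([[0, 0, 5, 6, 7]])  # 0 = pad token
print(infer_attention_mask(ids, pad_token_id=0, eos_token_id=2))
# tensor([[0, 0, 1, 1, 1]])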
@@ -911,7 +911,8 @@ def forward(
 
         if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
             raise ValueError(
-                "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
+                "You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, "
+                "and must specify either one"
             )
 
         legacy_processing = False
@@ -424,7 +424,8 @@ def forward(
 
         if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
             raise ValueError(
-                "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
+                "You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, "
+                "and must specify either one"
             )
 
         legacy_processing = False
@@ -657,7 +657,8 @@ def forward(
 
         if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
             raise ValueError(
-                "You cannot specify both pixel_values/pixel_values_videos and inputs_embeds at the same time, and must specify either one"
+                "You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, "
+                "and must specify either one"
             )
 
         if inputs_embeds is None:
4 changes: 2 additions & 2 deletions src/transformers/models/musicgen/modeling_musicgen.py
@@ -1562,7 +1562,7 @@ def generate(
 
         if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
             model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
-                input_ids, generation_config._pad_token_tensor, generation_config._eos_token_tensor
+                input_ids, generation_config, model_kwargs
             )
 
         # 5. Prepare `max_length` depending on other stopping criteria.
@@ -2578,7 +2578,7 @@ def generate(
 
         if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
             model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
-                inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor
+                inputs_tensor, generation_config, model_kwargs
             )
 
         if "encoder_outputs" not in model_kwargs:
@@ -1484,7 +1484,7 @@ def generate(
 
         if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
             model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
-                input_ids, generation_config._pad_token_tensor, generation_config._eos_token_tensor
+                input_ids, generation_config, model_kwargs
             )
 
         # 5. Prepare `max_length` depending on other stopping criteria.
@@ -2425,7 +2425,7 @@ def generate(
 
         if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
             model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
-                inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor
+                inputs_tensor, generation_config, model_kwargs
            )
 
         if "encoder_hidden_states" not in model_kwargs:
3 changes: 2 additions & 1 deletion src/transformers/models/video_llava/modeling_video_llava.py
@@ -534,7 +534,8 @@ def forward(
 
         if (pixel_values_images is not None or pixel_values_videos is not None) and inputs_embeds is not None:
             raise ValueError(
-                "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
+                "You cannot specify both `pixel_values_images`/`pixel_values_videos` and `inputs_embeds` at the same "
+                "time, and must specify either one"
             )
 
         legacy_processing = False