Fix generate with inputs_embeds as input (#32493)
* I think inputs_embeds has ndim == 3

* fix sequence length catch

* add generate test

* [run-slow]olmo, persimmon, gemma, gemma2, qwen2, llama

* skip whisper

* fix bart test

* more fixes
molbap authored and ArthurZucker committed Aug 20, 2024
1 parent 084fe2e commit 74f57df
Showing 21 changed files with 176 additions and 109 deletions.
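
The first commit-message bullet points at the root cause: `inputs_embeds` is a 3-D tensor of shape `(batch_size, sequence_length, hidden_size)`, while `input_ids` is 2-D, so the old two-value unpacking in the static-cache branch failed whenever generation started from embeddings. A minimal, self-contained sketch of the failure and of the corrected unpacking (the shapes below are illustrative only, not taken from any checkpoint):

import torch

# inputs_embeds carries one embedding vector per token: (batch, seq_len, hidden)
inputs_embeds = torch.randn(2, 7, 4096)
input_ids = torch.randint(0, 32_000, (2, 7))

# The old code assumed a 2-D tensor and unpacked two values:
try:
    batch_size, sequence_length = inputs_embeds.shape
except ValueError as err:
    print(err)  # too many values to unpack (expected 2)

# The fixed code unpacks three values and discards the hidden dimension:
batch_size, sequence_length, _ = inputs_embeds.shape
assert (batch_size, sequence_length) == tuple(input_ids.shape)
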
15 changes: 8 additions & 7 deletions src/transformers/models/cohere/modeling_cohere.py
@@ -1132,17 +1132,18 @@ def prepare_inputs_for_generation(
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
 
         if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if inputs_embeds is not None:
-                batch_size, sequence_length = inputs_embeds.shape
-                device = inputs_embeds.device
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
             else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device
 
             dtype = self.lm_head.weight.dtype
             min_dtype = torch.finfo(dtype).min
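
For orientation, the branch patched above only runs when `generate` is driven by embeddings instead of token ids and a static cache is in use. A hedged sketch of that call path follows; the checkpoint name and generation arguments are placeholders, not taken from the commit or its tests:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint; any decoder-only model with this prepare_inputs_for_generation behaves the same way.
model_id = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

input_ids = tokenizer("Hello", return_tensors="pt").input_ids.to(model.device)
inputs_embeds = model.get_input_embeddings()(input_ids)

# Passing inputs_embeds (and no input_ids) together with a static cache exercises the fixed branch:
out = model.generate(
    inputs_embeds=inputs_embeds,
    max_new_tokens=8,
    cache_implementation="static",
)
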
15 changes: 8 additions & 7 deletions src/transformers/models/dbrx/modeling_dbrx.py
@@ -1403,17 +1403,18 @@ def prepare_inputs_for_generation(
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
 
         if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if inputs_embeds is not None:
-                batch_size, sequence_length = inputs_embeds.shape
-                device = inputs_embeds.device
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
             else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device
 
             dtype = self.lm_head.weight.dtype
             min_dtype = torch.finfo(dtype).min
3 changes: 2 additions & 1 deletion src/transformers/models/falcon/modeling_falcon.py
@@ -1126,7 +1126,8 @@ def prepare_inputs_for_generation(
         if inputs_embeds is not None and past_key_values is None:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
 
         model_inputs.update(
             {
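
The added comment defers to the existing `position_ids` clone, which these files justify in terms of keeping tensor layout stable across decoding steps under `torch.compile` (a sliced view can change stride between calls and force CUDA graph re-capture). The sketch below only illustrates the tensor-level effect of `clone(memory_format=torch.contiguous_format)`; the compile interaction is an assumption based on that existing comment, not something shown in this diff:

import torch

# A sliced view keeps the parent tensor's strides, so the same logical shape
# can arrive with a different memory layout on different decoding steps.
full_ids = torch.arange(20).reshape(2, 10)
view = full_ids[:, -1:]                 # stride (10, 1), not contiguous
print(view.stride(), view.is_contiguous())

# Cloning with a contiguous memory format hands generate a tensor whose
# layout is identical every step, which is what the added comment alludes to.
stable = view.clone(memory_format=torch.contiguous_format)
print(stable.stride(), stable.is_contiguous())
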
15 changes: 8 additions & 7 deletions src/transformers/models/gemma/modeling_gemma.py
@@ -1143,17 +1143,18 @@ def prepare_inputs_for_generation(
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
 
         if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if inputs_embeds is not None:
-                batch_size, sequence_length = inputs_embeds.shape
-                device = inputs_embeds.device
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
             else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device
 
             dtype = self.lm_head.weight.dtype
             min_dtype = torch.finfo(dtype).min
19 changes: 4 additions & 15 deletions src/transformers/models/gemma2/modeling_gemma2.py
@@ -104,7 +104,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
             causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                 padding_mask, min_dtype
             )
-
     return causal_mask
 
 
@@ -301,7 +300,6 @@ def forward(
             attn_weights = attn_weights / self.config.attn_logit_softcapping
             attn_weights = torch.tanh(attn_weights)
             attn_weights = attn_weights * self.config.attn_logit_softcapping
-
         if attention_mask is not None:  # no matter the length, we just slice it
             causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
             attn_weights = attn_weights + causal_mask
@@ -502,11 +500,9 @@ def forward(
 
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
-
         causal_mask = attention_mask
         if attention_mask is not None:
             causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
-
         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
         if query_states.device.type == "cuda" and causal_mask is not None:
@@ -517,7 +513,6 @@ def forward(
         # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
         # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
         is_causal = True if causal_mask is None and q_len > 1 else False
-
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
@@ -583,7 +578,6 @@ def forward(
                 attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
                 if attention_mask.shape[-1] <= 1:  # when decoding
                     attention_mask = attention_mask[:, :, :, -self.sliding_window :]
-
         residual = hidden_states
 
         hidden_states = self.input_layernorm(hidden_states)
@@ -996,7 +990,6 @@ def forward(
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         outputs = self.model(
             input_ids=input_ids,
@@ -1063,7 +1056,6 @@ def prepare_inputs_for_generation(
                 input_ids = input_ids[:, -cache_position.shape[0] :]
             elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                 input_ids = input_ids[:, cache_position]
-
         if attention_mask is not None and position_ids is None:
             # create position_ids on the fly for batch generation
             position_ids = attention_mask.long().cumsum(-1) - 1
@@ -1079,22 +1071,20 @@ def prepare_inputs_for_generation(
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
             # The clone here is for the same reason as for `position_ids`.
-            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format)}
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
 
         if isinstance(past_key_values, HybridCache) and attention_mask.ndim == 2:
             if inputs_embeds is not None:
                 batch_size, sequence_length = inputs_embeds.shape
                 device = inputs_embeds.device
             else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
-
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device
             dtype = self.lm_head.weight.dtype
             min_dtype = torch.finfo(dtype).min
-
             attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
                 attention_mask,
                 sequence_length=sequence_length,
@@ -1105,7 +1095,6 @@ def prepare_inputs_for_generation(
                 cache_position=cache_position,
                 batch_size=batch_size,
             )
-
             model_inputs.update(
                 {
                     "position_ids": position_ids,
3 changes: 2 additions & 1 deletion src/transformers/models/gpt_neo/modeling_gpt_neo.py
@@ -818,7 +818,8 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_
         if inputs_embeds is not None and past_key_values is None:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
 
         model_inputs.update(
             {
3 changes: 2 additions & 1 deletion src/transformers/models/gptj/modeling_gptj.py
@@ -967,7 +967,8 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_
         if inputs_embeds is not None and past_key_values is None:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
 
         model_inputs.update(
             {
15 changes: 8 additions & 7 deletions src/transformers/models/llama/modeling_llama.py
@@ -1265,17 +1265,18 @@ def prepare_inputs_for_generation(
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
 
         if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if inputs_embeds is not None:
-                batch_size, sequence_length = inputs_embeds.shape
-                device = inputs_embeds.device
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
             else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device
 
             dtype = self.lm_head.weight.dtype
             min_dtype = torch.finfo(dtype).min
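
The second commit-message bullet, "fix sequence length catch", corresponds to the switch from checking the raw `inputs_embeds` argument to checking `model_inputs["inputs_embeds"]`: on decoding steps after the first, the embeddings passed at step 0 can still be visible even though the step actually runs on `input_ids`, so the old check could measure the sequence length from the wrong tensor. A small sketch of that reading of the diff, with hypothetical values:

import torch

# Hypothetical later decoding step: embeddings were passed at step 0 and are
# still around, but this step was prepared to run on input_ids.
inputs_embeds = torch.randn(1, 5, 4096)
cache_position = torch.tensor([5])                     # the first step would be 0
model_inputs = {"input_ids": torch.tensor([[42]]), "inputs_embeds": None}

# The old check looked at the raw argument and would size the mask from the
# stale 5-token prompt embeddings instead of the single new token:
takes_embeds_branch_old = inputs_embeds is not None                     # True
# The new check looks at what was actually selected for this step:
takes_embeds_branch_new = model_inputs["inputs_embeds"] is not None     # False
print(takes_embeds_branch_old, takes_embeds_branch_new)
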
15 changes: 8 additions & 7 deletions src/transformers/models/nemotron/modeling_nemotron.py
@@ -1136,17 -1136,18 @@ def prepare_inputs_for_generation(
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
 
         if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if inputs_embeds is not None:
-                batch_size, sequence_length = inputs_embeds.shape
-                device = inputs_embeds.device
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
             else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device
 
             dtype = self.lm_head.weight.dtype
             min_dtype = torch.finfo(dtype).min
15 changes: 8 additions & 7 deletions src/transformers/models/olmo/modeling_olmo.py
@@ -1176,17 +1176,18 @@ def prepare_inputs_for_generation(
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
 
         if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if inputs_embeds is not None:
-                batch_size, sequence_length = inputs_embeds.shape
-                device = inputs_embeds.device
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
             else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device
 
             dtype = self.lm_head.weight.dtype
             min_dtype = torch.finfo(dtype).min
15 changes: 8 additions & 7 deletions src/transformers/models/persimmon/modeling_persimmon.py
@@ -993,17 +993,18 @@ def prepare_inputs_for_generation(
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
 
         if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if inputs_embeds is not None:
-                batch_size, sequence_length = inputs_embeds.shape
-                device = inputs_embeds.device
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
             else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device
 
             dtype = self.lm_head.weight.dtype
             min_dtype = torch.finfo(dtype).min
15 changes: 8 additions & 7 deletions src/transformers/models/phi/modeling_phi.py
@@ -1278,17 +1278,18 @@ def prepare_inputs_for_generation(
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
 
         if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if inputs_embeds is not None:
-                batch_size, sequence_length = inputs_embeds.shape
-                device = inputs_embeds.device
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
             else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device
 
             dtype = self.lm_head.weight.dtype
             min_dtype = torch.finfo(dtype).min
15 changes: 8 additions & 7 deletions src/transformers/models/phi3/modeling_phi3.py
@@ -1318,17 +1318,18 @@ def prepare_inputs_for_generation(
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
 
         if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if inputs_embeds is not None:
-                batch_size, sequence_length = inputs_embeds.shape
-                device = inputs_embeds.device
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
             else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device
 
             dtype = self.lm_head.weight.dtype
             min_dtype = torch.finfo(dtype).min
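
The commit message also mentions adding a generate test, which is not among the diffs rendered above. A hedged sketch of the behaviour such a test would guard, using a placeholder tiny checkpoint and comparing greedy continuations obtained from `input_ids` and from the matching `inputs_embeds`:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder tiny checkpoint; the actual test in the repository may use different models and helpers.
model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

input_ids = tokenizer("The quick brown fox", return_tensors="pt").input_ids

# Greedy decode from token ids, then from the matching embeddings.
ids_out = model.generate(input_ids, do_sample=False, max_new_tokens=5)
embeds_out = model.generate(
    inputs_embeds=model.get_input_embeddings()(input_ids),
    do_sample=False,
    max_new_tokens=5,
)

# With inputs_embeds only, generate returns just the newly generated tokens,
# so compare against the continuation part of the ids-based output.
assert torch.equal(ids_out[:, input_ids.shape[1]:], embeds_out)
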