Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove irrelevant comment in _update_causal_mask #32363

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions src/transformers/models/bloom/modeling_bloom.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,11 +790,6 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114

if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
Expand Down
5 changes: 0 additions & 5 deletions src/transformers/models/chameleon/modeling_chameleon.py
Original file line number Diff line number Diff line change
Expand Up @@ -1433,11 +1433,6 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114

if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
Expand Down
14 changes: 8 additions & 6 deletions src/transformers/models/cohere/modeling_cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -903,6 +903,7 @@ def forward(
attentions=all_self_attns,
)

# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
Expand All @@ -911,11 +912,6 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114

if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
Expand Down Expand Up @@ -1136,7 +1132,13 @@ def prepare_inputs_for_generation(
else:
model_inputs = {"input_ids": input_ids}

if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
# Prepare a 4D attention mask here so as to have inputs statically shaped when using StaticCache (torch.compile compatibility).
# `flash_attention_2` attention implementation requires a 2D attention mask.
if (
isinstance(past_key_values, StaticCache)
and attention_mask.ndim == 2
and self.config._attn_implementation != "flash_attention_2"
):
if inputs_embeds is not None:
batch_size, sequence_length = inputs_embeds.shape
device = inputs_embeds.device
Expand Down
13 changes: 7 additions & 6 deletions src/transformers/models/dbrx/modeling_dbrx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1163,11 +1163,6 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114

if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
Expand Down Expand Up @@ -1407,7 +1402,13 @@ def prepare_inputs_for_generation(
else:
model_inputs = {"input_ids": input_ids}

if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
# Prepare a 4D attention mask here so as to have inputs statically shaped when using StaticCache (torch.compile compatibility).
# `flash_attention_2` attention implementation requires a 2D attention mask.
if (
isinstance(past_key_values, StaticCache)
and attention_mask.ndim == 2
and self.config._attn_implementation != "flash_attention_2"
):
if inputs_embeds is not None:
batch_size, sequence_length = inputs_embeds.shape
device = inputs_embeds.device
Expand Down
6 changes: 1 addition & 5 deletions src/transformers/models/gemma/modeling_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -922,6 +922,7 @@ def forward(
attentions=all_self_attns,
)

# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
Expand All @@ -930,11 +931,6 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114

if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
Expand Down
5 changes: 0 additions & 5 deletions src/transformers/models/jetmoe/modeling_jetmoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1135,11 +1135,6 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114

if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
Expand Down
13 changes: 7 additions & 6 deletions src/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -1044,11 +1044,6 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114

if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
Expand Down Expand Up @@ -1269,7 +1264,13 @@ def prepare_inputs_for_generation(
else:
model_inputs = {"input_ids": input_ids}

if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
# Prepare a 4D attention mask here so as to have inputs statically shaped when using StaticCache (torch.compile compatibility).
# `flash_attention_2` attention implementation requires a 2D attention mask.
if (
isinstance(past_key_values, StaticCache)
and attention_mask.ndim == 2
and self.config._attn_implementation != "flash_attention_2"
):
if inputs_embeds is not None:
batch_size, sequence_length = inputs_embeds.shape
device = inputs_embeds.device
Expand Down
5 changes: 0 additions & 5 deletions src/transformers/models/mixtral/modeling_mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -1117,11 +1117,6 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114

if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
Expand Down
13 changes: 7 additions & 6 deletions src/transformers/models/olmo/modeling_olmo.py
Original file line number Diff line number Diff line change
Expand Up @@ -957,11 +957,6 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114

if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
Expand Down Expand Up @@ -1180,7 +1175,13 @@ def prepare_inputs_for_generation(
else:
model_inputs = {"input_ids": input_ids}

if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
# Prepare a 4D attention mask here so as to have inputs statically shaped when using StaticCache (torch.compile compatibility).
# `flash_attention_2` attention implementation requires a 2D attention mask.
if (
isinstance(past_key_values, StaticCache)
and attention_mask.ndim == 2
and self.config._attn_implementation != "flash_attention_2"
):
if inputs_embeds is not None:
batch_size, sequence_length = inputs_embeds.shape
device = inputs_embeds.device
Expand Down
13 changes: 7 additions & 6 deletions src/transformers/models/persimmon/modeling_persimmon.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,11 +769,6 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114

if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
Expand Down Expand Up @@ -997,7 +992,13 @@ def prepare_inputs_for_generation(
else:
model_inputs = {"input_ids": input_ids}

if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
# Prepare a 4D attention mask here so as to have inputs statically shaped when using StaticCache (torch.compile compatibility).
# `flash_attention_2` attention implementation requires a 2D attention mask.
if (
isinstance(past_key_values, StaticCache)
and attention_mask.ndim == 2
and self.config._attn_implementation != "flash_attention_2"
):
if inputs_embeds is not None:
batch_size, sequence_length = inputs_embeds.shape
device = inputs_embeds.device
Expand Down
13 changes: 7 additions & 6 deletions src/transformers/models/phi/modeling_phi.py
Original file line number Diff line number Diff line change
Expand Up @@ -1053,11 +1053,6 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114

if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
Expand Down Expand Up @@ -1282,7 +1277,13 @@ def prepare_inputs_for_generation(
else:
model_inputs = {"input_ids": input_ids}

if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
# Prepare a 4D attention mask here so as to have inputs statically shaped when using StaticCache (torch.compile compatibility).
# `flash_attention_2` attention implementation requires a 2D attention mask.
if (
isinstance(past_key_values, StaticCache)
and attention_mask.ndim == 2
and self.config._attn_implementation != "flash_attention_2"
):
if inputs_embeds is not None:
batch_size, sequence_length = inputs_embeds.shape
device = inputs_embeds.device
Expand Down
13 changes: 7 additions & 6 deletions src/transformers/models/phi3/modeling_phi3.py
Original file line number Diff line number Diff line change
Expand Up @@ -1096,11 +1096,6 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114

if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
Expand Down Expand Up @@ -1325,7 +1320,13 @@ def prepare_inputs_for_generation(
else:
model_inputs = {"input_ids": input_ids}

if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
# Prepare a 4D attention mask here so as to have inputs statically shaped when using StaticCache (torch.compile compatibility).
# `flash_attention_2` attention implementation requires a 2D attention mask.
if (
isinstance(past_key_values, StaticCache)
and attention_mask.ndim == 2
and self.config._attn_implementation != "flash_attention_2"
):
if inputs_embeds is not None:
batch_size, sequence_length = inputs_embeds.shape
device = inputs_embeds.device
Expand Down
13 changes: 7 additions & 6 deletions src/transformers/models/qwen2/modeling_qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,11 +955,6 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114

if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
Expand Down Expand Up @@ -1177,7 +1172,13 @@ def prepare_inputs_for_generation(
else:
model_inputs = {"input_ids": input_ids}

if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
# Prepare a 4D attention mask here so as to have inputs statically shaped when using StaticCache (torch.compile compatibility).
# `flash_attention_2` attention implementation requires a 2D attention mask.
if (
isinstance(past_key_values, StaticCache)
and attention_mask.ndim == 2
and self.config._attn_implementation != "flash_attention_2"
):
if inputs_embeds is not None:
batch_size, sequence_length = inputs_embeds.shape
device = inputs_embeds.device
Expand Down
13 changes: 7 additions & 6 deletions src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1128,11 +1128,6 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114

if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
Expand Down Expand Up @@ -1373,7 +1368,13 @@ def prepare_inputs_for_generation(
else:
model_inputs = {"input_ids": input_ids}

if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
# Prepare a 4D attention mask here so as to have inputs statically shaped when using StaticCache (torch.compile compatibility).
# `flash_attention_2` attention implementation requires a 2D attention mask.
if (
isinstance(past_key_values, StaticCache)
and attention_mask.ndim == 2
and self.config._attn_implementation != "flash_attention_2"
):
if inputs_embeds is not None:
batch_size, sequence_length = inputs_embeds.shape
device = inputs_embeds.device
Expand Down
Loading
Loading