diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py
index c1caae6c6857..b57ce5cb234b 100644
--- a/src/transformers/models/bloom/modeling_bloom.py
+++ b/src/transformers/models/bloom/modeling_bloom.py
@@ -790,11 +790,6 @@ def _update_causal_mask(
         past_key_values: Cache,
         output_attentions: bool,
     ):
-        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
-        # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
-        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
-        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
                 return attention_mask
diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py
index 955c5ed6839d..23334311ca95 100644
--- a/src/transformers/models/chameleon/modeling_chameleon.py
+++ b/src/transformers/models/chameleon/modeling_chameleon.py
@@ -1433,11 +1433,6 @@ def _update_causal_mask(
         past_key_values: Cache,
         output_attentions: bool,
     ):
-        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
-        # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
-        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
-        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
                 return attention_mask
diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py
index 4bc03ade8908..c8aa82d260f8 100644
--- a/src/transformers/models/cohere/modeling_cohere.py
+++ b/src/transformers/models/cohere/modeling_cohere.py
@@ -903,6 +903,7 @@ def forward(
             attentions=all_self_attns,
         )
 
+    # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
     def _update_causal_mask(
         self,
         attention_mask: torch.Tensor,
@@ -911,11 +912,6 @@ def _update_causal_mask(
         past_key_values: Cache,
         output_attentions: bool,
     ):
-        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
-        # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
-        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
-        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
                 return attention_mask
@@ -1136,7 +1132,13 @@ def prepare_inputs_for_generation(
         else:
             model_inputs = {"input_ids": input_ids}
 
-        if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
+        # Prepare a 4D attention mask here so as to have inputs statically shaped when using StaticCache (torch.compile compatibility).
+        # `flash_attention_2` attention implementation requires a 2D attention mask.
+        if (
+            isinstance(past_key_values, StaticCache)
+            and attention_mask.ndim == 2
+            and self.config._attn_implementation != "flash_attention_2"
+        ):
             if inputs_embeds is not None:
                 batch_size, sequence_length = inputs_embeds.shape
                 device = inputs_embeds.device
diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py
index f07b910fcf3e..f5286b1f8857 100644
--- a/src/transformers/models/dbrx/modeling_dbrx.py
+++ b/src/transformers/models/dbrx/modeling_dbrx.py
@@ -1163,11 +1163,6 @@ def _update_causal_mask(
         past_key_values: Cache,
         output_attentions: bool,
     ):
-        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
-        # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
-        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
-        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
                 return attention_mask
@@ -1407,7 +1402,13 @@ def prepare_inputs_for_generation(
         else:
             model_inputs = {"input_ids": input_ids}
 
-        if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
+        # Prepare a 4D attention mask here so as to have inputs statically shaped when using StaticCache (torch.compile compatibility).
+        # `flash_attention_2` attention implementation requires a 2D attention mask.
+        if (
+            isinstance(past_key_values, StaticCache)
+            and attention_mask.ndim == 2
+            and self.config._attn_implementation != "flash_attention_2"
+        ):
             if inputs_embeds is not None:
                 batch_size, sequence_length = inputs_embeds.shape
                 device = inputs_embeds.device
diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py
index 271023471720..d1fe960b5d63 100644
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -922,6 +922,7 @@ def forward(
             attentions=all_self_attns,
         )
 
+    # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
     def _update_causal_mask(
         self,
         attention_mask: torch.Tensor,
@@ -930,11 +931,6 @@ def _update_causal_mask(
         past_key_values: Cache,
         output_attentions: bool,
     ):
-        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
-        # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
-        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
-        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
                 return attention_mask
diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py
index a3fc645e5aea..8be93f9633a8 100644
--- a/src/transformers/models/jetmoe/modeling_jetmoe.py
+++ b/src/transformers/models/jetmoe/modeling_jetmoe.py
@@ -1135,11 +1135,6 @@ def _update_causal_mask(
         past_key_values: Cache,
         output_attentions: bool,
     ):
-        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
-        # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
-        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
-        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
                 return attention_mask
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index dd053c805fb8..d68788082656 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -1044,11 +1044,6 @@ def _update_causal_mask(
         past_key_values: Cache,
         output_attentions: bool,
     ):
-        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
-        # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
-        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
-        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
                 return attention_mask
@@ -1269,7 +1264,13 @@ def prepare_inputs_for_generation(
         else:
             model_inputs = {"input_ids": input_ids}
 
-        if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
+        # Prepare a 4D attention mask here so as to have inputs statically shaped when using StaticCache (torch.compile compatibility).
+        # `flash_attention_2` attention implementation requires a 2D attention mask.
+        if (
+            isinstance(past_key_values, StaticCache)
+            and attention_mask.ndim == 2
+            and self.config._attn_implementation != "flash_attention_2"
+        ):
             if inputs_embeds is not None:
                 batch_size, sequence_length = inputs_embeds.shape
                 device = inputs_embeds.device
diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py
index 1ff2a6c250e2..f5641264f8e6 100644
--- a/src/transformers/models/mixtral/modeling_mixtral.py
+++ b/src/transformers/models/mixtral/modeling_mixtral.py
@@ -1117,11 +1117,6 @@ def _update_causal_mask(
         past_key_values: Cache,
         output_attentions: bool,
     ):
-        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
-        # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
-        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
-        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
                 return attention_mask
diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py
index 61a8a2bf6b60..371b5884994f 100644
--- a/src/transformers/models/olmo/modeling_olmo.py
+++ b/src/transformers/models/olmo/modeling_olmo.py
@@ -957,11 +957,6 @@ def _update_causal_mask(
         past_key_values: Cache,
         output_attentions: bool,
     ):
-        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
-        # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
-        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
-        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
                 return attention_mask
@@ -1180,7 +1175,13 @@ def prepare_inputs_for_generation(
         else:
             model_inputs = {"input_ids": input_ids}
 
-        if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
+        # Prepare a 4D attention mask here so as to have inputs statically shaped when using StaticCache (torch.compile compatibility).
+        # `flash_attention_2` attention implementation requires a 2D attention mask.
+        if (
+            isinstance(past_key_values, StaticCache)
+            and attention_mask.ndim == 2
+            and self.config._attn_implementation != "flash_attention_2"
+        ):
             if inputs_embeds is not None:
                 batch_size, sequence_length = inputs_embeds.shape
                 device = inputs_embeds.device
diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py
index 885d74426636..8069d8beb507 100644
--- a/src/transformers/models/persimmon/modeling_persimmon.py
+++ b/src/transformers/models/persimmon/modeling_persimmon.py
@@ -769,11 +769,6 @@ def _update_causal_mask(
         past_key_values: Cache,
         output_attentions: bool,
     ):
-        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
-        # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
-        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
-        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
                 return attention_mask
@@ -997,7 +992,13 @@ def prepare_inputs_for_generation(
         else:
             model_inputs = {"input_ids": input_ids}
 
-        if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
+        # Prepare a 4D attention mask here so as to have inputs statically shaped when using StaticCache (torch.compile compatibility).
+        # `flash_attention_2` attention implementation requires a 2D attention mask.
+        if (
+            isinstance(past_key_values, StaticCache)
+            and attention_mask.ndim == 2
+            and self.config._attn_implementation != "flash_attention_2"
+        ):
             if inputs_embeds is not None:
                 batch_size, sequence_length = inputs_embeds.shape
                 device = inputs_embeds.device
diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py
index c44545976316..cdd9cee41e42 100644
--- a/src/transformers/models/phi/modeling_phi.py
+++ b/src/transformers/models/phi/modeling_phi.py
@@ -1053,11 +1053,6 @@ def _update_causal_mask(
         past_key_values: Cache,
         output_attentions: bool,
     ):
-        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
-        # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
-        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
-        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
                 return attention_mask
@@ -1282,7 +1277,13 @@ def prepare_inputs_for_generation(
         else:
             model_inputs = {"input_ids": input_ids}
 
-        if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
+        # Prepare a 4D attention mask here so as to have inputs statically shaped when using StaticCache (torch.compile compatibility).
+        # `flash_attention_2` attention implementation requires a 2D attention mask.
+        if (
+            isinstance(past_key_values, StaticCache)
+            and attention_mask.ndim == 2
+            and self.config._attn_implementation != "flash_attention_2"
+        ):
             if inputs_embeds is not None:
                 batch_size, sequence_length = inputs_embeds.shape
                 device = inputs_embeds.device
diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py
index 000c6f4c0aaf..8e8411cb4497 100644
--- a/src/transformers/models/phi3/modeling_phi3.py
+++ b/src/transformers/models/phi3/modeling_phi3.py
@@ -1096,11 +1096,6 @@ def _update_causal_mask(
         past_key_values: Cache,
         output_attentions: bool,
     ):
-        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
-        # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
-        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
-        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
                 return attention_mask
@@ -1325,7 +1320,13 @@ def prepare_inputs_for_generation(
         else:
             model_inputs = {"input_ids": input_ids}
 
-        if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
+        # Prepare a 4D attention mask here so as to have inputs statically shaped when using StaticCache (torch.compile compatibility).
+        # `flash_attention_2` attention implementation requires a 2D attention mask.
+        if (
+            isinstance(past_key_values, StaticCache)
+            and attention_mask.ndim == 2
+            and self.config._attn_implementation != "flash_attention_2"
+        ):
             if inputs_embeds is not None:
                 batch_size, sequence_length = inputs_embeds.shape
                 device = inputs_embeds.device
diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py
index c66380ca84f8..2cbd700a199d 100644
--- a/src/transformers/models/qwen2/modeling_qwen2.py
+++ b/src/transformers/models/qwen2/modeling_qwen2.py
@@ -955,11 +955,6 @@ def _update_causal_mask(
         past_key_values: Cache,
         output_attentions: bool,
     ):
-        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
-        # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
-        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
-        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
                 return attention_mask
@@ -1177,7 +1172,13 @@ def prepare_inputs_for_generation(
         else:
             model_inputs = {"input_ids": input_ids}
 
-        if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
+        # Prepare a 4D attention mask here so as to have inputs statically shaped when using StaticCache (torch.compile compatibility).
+        # `flash_attention_2` attention implementation requires a 2D attention mask.
+        if (
+            isinstance(past_key_values, StaticCache)
+            and attention_mask.ndim == 2
+            and self.config._attn_implementation != "flash_attention_2"
+        ):
             if inputs_embeds is not None:
                 batch_size, sequence_length = inputs_embeds.shape
                 device = inputs_embeds.device
diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
index 2e0c75ed84da..e5d0827a1061 100644
--- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
+++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
@@ -1128,11 +1128,6 @@ def _update_causal_mask(
         past_key_values: Cache,
         output_attentions: bool,
     ):
-        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
-        # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
-        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
-        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
                 return attention_mask
@@ -1373,7 +1368,13 @@ def prepare_inputs_for_generation(
         else:
             model_inputs = {"input_ids": input_ids}
 
-        if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
+        # Prepare a 4D attention mask here so as to have inputs statically shaped when using StaticCache (torch.compile compatibility).
+        # `flash_attention_2` attention implementation requires a 2D attention mask.
+        if (
+            isinstance(past_key_values, StaticCache)
+            and attention_mask.ndim == 2
+            and self.config._attn_implementation != "flash_attention_2"
+        ):
             if inputs_embeds is not None:
                 batch_size, sequence_length = inputs_embeds.shape
                 device = inputs_embeds.device
diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py
index 2f326184e120..cd1e52054fdd 100755
--- a/src/transformers/models/stablelm/modeling_stablelm.py
+++ b/src/transformers/models/stablelm/modeling_stablelm.py
@@ -1046,11 +1046,6 @@ def _update_causal_mask(
         past_key_values: Cache,
         output_attentions: bool,
     ):
-        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
-        # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
-        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
-        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
                 return attention_mask
@@ -1275,7 +1270,13 @@ def prepare_inputs_for_generation(
         else:
             model_inputs = {"input_ids": input_ids}
 
-        if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
+        # Prepare a 4D attention mask here so as to have inputs statically shaped when using StaticCache (torch.compile compatibility).
+        # `flash_attention_2` attention implementation requires a 2D attention mask.
+        if (
+            isinstance(past_key_values, StaticCache)
+            and attention_mask.ndim == 2
+            and self.config._attn_implementation != "flash_attention_2"
+        ):
             if inputs_embeds is not None:
                 batch_size, sequence_length = inputs_embeds.shape
                 device = inputs_embeds.device
diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py
index 7cf767b70e04..0ba46d16ff89 100644
--- a/src/transformers/models/starcoder2/modeling_starcoder2.py
+++ b/src/transformers/models/starcoder2/modeling_starcoder2.py
@@ -929,11 +929,6 @@ def _update_causal_mask(
         past_key_values: Cache,
         output_attentions: bool,
     ):
-        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
-        # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
-        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
-        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
                 return attention_mask
@@ -1153,7 +1148,13 @@ def prepare_inputs_for_generation(
         else:
             model_inputs = {"input_ids": input_ids}
 
-        if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
+        # Prepare a 4D attention mask here so as to have inputs statically shaped when using StaticCache (torch.compile compatibility).
+        # `flash_attention_2` attention implementation requires a 2D attention mask.
+        if (
+            isinstance(past_key_values, StaticCache)
+            and attention_mask.ndim == 2
+            and self.config._attn_implementation != "flash_attention_2"
+        ):
             if inputs_embeds is not None:
                 batch_size, sequence_length = inputs_embeds.shape
                 device = inputs_embeds.device
diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py
index 8785d5681f73..32b07cce923a 100644
--- a/src/transformers/models/whisper/modeling_whisper.py
+++ b/src/transformers/models/whisper/modeling_whisper.py
@@ -1428,11 +1428,6 @@ def _update_causal_mask(
         past_key_values: Cache,
         output_attentions: bool,
     ):
-        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
-        # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
-        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
-        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
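Every `prepare_inputs_for_generation` hunk above adds the same guard: the statically shaped 4D mask is only built when a `StaticCache` is in use and the attention backend is not `flash_attention_2`, which consumes the raw 2D padding mask instead. The standalone sketch below is not part of the patch; it uses a hypothetical `_FakeStaticCache` stand-in (not the real `transformers.StaticCache`) purely to show how the added condition branches.

```python
# Illustrative sketch only -- not part of the diff above.
import torch


class _FakeStaticCache:
    """Hypothetical stand-in for transformers.StaticCache, used only for this demo."""


def should_build_4d_mask(past_key_values, attention_mask, attn_implementation) -> bool:
    # Mirrors the new guard: prepare the 4D, statically shaped mask only when a
    # static KV cache is used AND the backend is not flash_attention_2.
    return (
        isinstance(past_key_values, _FakeStaticCache)
        and attention_mask.ndim == 2
        and attn_implementation != "flash_attention_2"
    )


mask_2d = torch.ones(1, 5, dtype=torch.long)  # the 2D padding mask `generate` passes in
print(should_build_4d_mask(_FakeStaticCache(), mask_2d, "sdpa"))               # True
print(should_build_4d_mask(_FakeStaticCache(), mask_2d, "flash_attention_2"))  # False
```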