propagate changes to gemma and cohere
gante committed Apr 24, 2024
1 parent 6c3ef09 commit 872411c
Showing 4 changed files with 51 additions and 33 deletions.
21 changes: 14 additions & 7 deletions src/transformers/models/cohere/modeling_cohere.py
@@ -533,10 +533,13 @@ def forward(
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is
# implemented.
logger.warning_once(
"CohereModel is using CohereSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
"CohereModel is using CohereSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` "
"does not support `output_attentions=True`. Falling back to the manual attention implementation, "
"but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
@@ -583,15 +586,19 @@ def forward(
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom
# attn_mask. Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and causal_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()

# In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, rather
# relying on the `is_causal` argument.
# In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash
# Attention 2 backend, rather relying on the `is_causal` argument. If using static cache, we need to drop the
# empty KV entries
if causal_mask is None and cache_position is not None:
key_states = key_states[:, :, : cache_position[-1] + 1, :]
value_states = value_states[:, :, : cache_position[-1] + 1, :]
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
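
To make the change above easier to read outside the diff, here is a minimal, self-contained sketch of the SDPA pattern this commit propagates to the Cohere and Gemma attention classes (the same logic already landed in Llama, shown further down). The function name, argument list, and the `is_causal` heuristic are illustrative assumptions, not a copy of the Transformers code.

import torch

def sdpa_attention_sketch(query_states, key_states, value_states, causal_mask, cache_position):
    # torch==2.1.2's memory-efficient SDPA backend misbehaves with non-contiguous inputs
    # when a custom attn_mask is supplied, so force contiguity on CUDA in that case.
    if query_states.device.type == "cuda" and causal_mask is not None:
        query_states = query_states.contiguous()
        key_states = key_states.contiguous()
        value_states = value_states.contiguous()

    # With no explicit mask (the non-compiled path), a static cache may still hold
    # pre-allocated but unwritten KV slots past the current position. Slicing them off
    # keeps the attention output correct while letting SDPA dispatch to its Flash
    # Attention backend via `is_causal` instead of an explicit attention mask.
    if causal_mask is None and cache_position is not None:
        key_states = key_states[:, :, : cache_position[-1] + 1, :]
        value_states = value_states[:, :, : cache_position[-1] + 1, :]

    return torch.nn.functional.scaled_dot_product_attention(
        query_states,
        key_states,
        value_states,
        attn_mask=causal_mask,
        is_causal=causal_mask is None and query_states.shape[-2] > 1,  # assumed heuristic
    )

During single-token decoding with a static cache, this makes the eager path see only the populated part of the cache, so it behaves like a dynamic cache as far as SDPA dispatch is concerned.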
21 changes: 14 additions & 7 deletions src/transformers/models/gemma/modeling_gemma.py
@@ -521,10 +521,13 @@ def forward(
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is
# implemented.
logger.warning_once(
"GemmaModel is using GemmaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
"GemmaModel is using GemmaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does "
"not support `output_attentions=True`. Falling back to the manual attention implementation, "
"but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
@@ -563,15 +566,19 @@ def forward(
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom
# attn_mask. Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and causal_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()

# In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, rather
# relying on the `is_causal` argument.
# In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash
# Attention 2 backend, rather relying on the `is_causal` argument. If using static cache, we need to drop the
# empty KV entries
if causal_mask is None and cache_position is not None:
key_states = key_states[:, :, : cache_position[-1] + 1, :]
value_states = value_states[:, :, : cache_position[-1] + 1, :]
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
16 changes: 9 additions & 7 deletions src/transformers/models/llama/modeling_llama.py
@@ -616,10 +616,13 @@ def forward(
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is
# implemented.
logger.warning_once(
"LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
"LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does "
"not support `output_attentions=True`. Falling back to the manual attention implementation, "
"but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
@@ -660,8 +663,7 @@ def forward(
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom
# attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
# attn_mask. Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and causal_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
@@ -671,8 +673,8 @@ def forward(
# Attention 2 backend, rather relying on the `is_causal` argument. If using static cache, we need to drop the
# empty KV entries
if causal_mask is None and cache_position is not None:
key_states = key_states[:, :, :cache_position[-1]+1, :]
value_states = value_states[:, :, :cache_position[-1]+1, :]
key_states = key_states[:, :, : cache_position[-1] + 1, :]
value_states = value_states[:, :, : cache_position[-1] + 1, :]
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
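
The warning added to each SDPA attention class points users who need attention weights toward the eager implementation. A minimal usage sketch, assuming a placeholder Gemma checkpoint (the same applies to the Cohere and Llama classes above):

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-2b"  # placeholder checkpoint; any Cohere/Gemma/Llama checkpoint works the same way
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Loading with the eager attention implementation avoids the SDPA fallback warning
# and returns per-layer attention weights when requested.
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="eager")

inputs = tokenizer("Hello world", return_tensors="pt")
outputs = model(**inputs, output_attentions=True)
attentions = outputs.attentions  # tuple with one attention-weight tensor per layer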
26 changes: 14 additions & 12 deletions tests/models/llama/test_modeling_llama.py
@@ -20,9 +20,8 @@
import pytest
from parameterized import parameterized

from transformers import LlamaConfig, StaticCache, is_torch_available, logging, set_seed
from transformers import LlamaConfig, is_torch_available, set_seed
from transformers.testing_utils import (
CaptureLogger,
require_bitsandbytes,
require_flash_attn,
require_read_token,
@@ -684,17 +683,17 @@ def test_model_13b_greedy_generation(self):
@require_torch_gpu
@require_read_token
def test_compile_static_cache(self):
# `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 works as
# intended. See https://github.com/pytorch/pytorch/issues/121943
# `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2
# work as intended. See https://github.com/pytorch/pytorch/issues/121943
NUM_TOKENS_TO_GENERATE = 40
# Note on `EXPECTED_TEXT_COMPLETION`'s diff: the current value matches the original test if the original test
# was changed to have a cache of 53 tokens (as opposed to 4096).
EXPECTED_TEXT_COMPLETION = [
'Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial '
'reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe theory '
'of relativ',
'My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my '
'fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p'
"Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial "
"reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe theory "
"of relativ",
"My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my "
"fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
]

prompts = [
@@ -708,19 +707,22 @@ def test_compile_static_cache(self):
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)

# Dynamic Cache
# with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False)
dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text)

# Static Cache
generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static")
generated_ids = model.generate(
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
)
static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text)

# Static Cache + compile
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static")
generated_ids = model.generate(
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
)
static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text)

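
For readers who want to try the behaviour the updated test exercises outside the test harness, here is a rough end-to-end sketch. The checkpoint name, dtype, and device settings are assumptions, since the test's own setup is truncated above; the test asserts that all three code paths produce the same completion.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

inputs = tokenizer("Simply put, the theory of relativity states that ", return_tensors="pt").to(model.device)

# Dynamic cache (default code path)
dynamic_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False)

# Static cache: pre-allocated KV tensors; the slicing added in this commit keeps the
# non-compiled SDPA path on the fast `is_causal` dispatch despite the larger cache
static_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False, cache_implementation="static")

# Static cache + torch.compile, as in the test
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
compiled_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False, cache_implementation="static")

print(tokenizer.batch_decode(compiled_ids, skip_special_tokens=True))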
