fix other cache tests
gante committed Feb 19, 2024
1 parent ca3dcaf commit 1f14324
Showing 2 changed files with 8 additions and 6 deletions.
4 changes: 3 additions & 1 deletion src/transformers/models/llama/modeling_llama.py
@@ -1038,6 +1038,7 @@ def _update_causal_mask(self, attention_mask, input_tensor):
 
         batch_size, seq_length = input_tensor.shape[:2]
         dtype = input_tensor.dtype
+        device = input_tensor.device
 
         # support going beyond cached `max_position_embedding`
         if seq_length > self.causal_mask.shape[-1]:
@@ -1053,8 +1054,9 @@ def _update_causal_mask(self, attention_mask, input_tensor):
             (self.config.max_position_embeddings, self.config.max_position_embeddings),
             fill_value=torch.finfo(dtype).min,
         )
-        causal_mask = torch.triu(mask, diagonal=1).to(dtype)
+        causal_mask = torch.triu(mask, diagonal=1)
 
+        causal_mask = causal_mask.to(dtype=dtype, device=device)
         if attention_mask is not None and attention_mask.dim() == 2:
             mask_length = attention_mask.shape[-1]
             padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
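The modeling change above builds the triangular mask without an eager dtype cast and then moves it to the input's dtype and device in one step, so the mask always matches whatever the hidden states arrive with. A minimal standalone sketch of that pattern follows; `max_positions`, `dtype`, and `device` are illustrative stand-ins, not values taken from the diff.

import torch

# Illustrative stand-ins for config.max_position_embeddings and for the runtime
# dtype/device of `input_tensor`; these are assumptions, not the actual config.
max_positions = 8
dtype = torch.bfloat16
device = torch.device("cpu")

# Build the additive causal mask once (upper triangle = large negative bias),
# then cast it to the runtime dtype/device right before it is consumed.
mask = torch.full((max_positions, max_positions), fill_value=torch.finfo(dtype).min)
causal_mask = torch.triu(mask, diagonal=1)
causal_mask = causal_mask.to(dtype=dtype, device=device)

print(causal_mask.dtype, causal_mask.device)  # torch.bfloat16 cpu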
10 changes: 5 additions & 5 deletions tests/test_cache_utils.py
@@ -293,7 +293,7 @@ def test_sink_cache_iterative_prompts(self):
     @parameterized.expand(["eager", "sdpa", "flash_attention_2"])
     def test_static_cache_greedy_sampling_pad_left(self, attn_implementation):
         EXPECTED_GENERATION = [
-            "The best color is the one that complements the subject you are photograph",
+            "The best color is the one that complements the skin tone of the",
             "We should not undermind the issues at hand.\nWe should not undermind the issues",
         ]
 
@@ -333,18 +333,18 @@ def test_static_cache_greedy_sampling_pad_left(self, attn_implementation):
     @parameterized.expand(["eager", "sdpa", "flash_attention_2"])
     def test_static_cache_greedy_sampling_pad_right(self, attn_implementation):
         EXPECTED_GENERATION = [
-            "The best color is\n\n\n\n\n\n\n\n\n\n",
-            "We should not undermind the issues at hand, but address them head on.\nI think",
+            "The best color isЋ the one that complements the skin tone of",
+            "We should not undermind the issues at hand.\nWe should not undermind the issues",
         ]
 
         tokenizer = AutoTokenizer.from_pretrained(
-            "NousResearch/Llama-2-7b-chat-hf", padding_side="left", pad_token="<s>"
+            "NousResearch/Llama-2-7b-chat-hf", padding_side="right", pad_token="<s>"
         )
         model = AutoModelForCausalLM.from_pretrained(
             "NousResearch/Llama-2-7b-chat-hf",
             torch_dtype=torch.bfloat16,
             attn_implementation=attn_implementation,
-        ).to("cuda:1")
+        ).to(torch_device)
         inputs = tokenizer(
             ["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt"
         ).to(model.device)
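For context, the body of these tests (not shown in the diff) presumably switches to the static KV cache, greedy-decodes, and compares the decoded strings against EXPECTED_GENERATION. The sketch below is written under that assumption and reuses the setup from the hunk above; `cache_implementation = "static"` and `max_new_tokens=10` are assumptions about the test, not taken from this diff.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.testing_utils import torch_device  # resolves to "cuda", "cpu", ... as available

# Setup mirroring the right-padded variant in the hunk above.
tokenizer = AutoTokenizer.from_pretrained(
    "NousResearch/Llama-2-7b-chat-hf", padding_side="right", pad_token="<s>"
)
model = AutoModelForCausalLM.from_pretrained(
    "NousResearch/Llama-2-7b-chat-hf", torch_dtype=torch.bfloat16
).to(torch_device)
inputs = tokenizer(
    ["The best color is", "We should not undermind the issues at hand"],
    padding=True,
    return_tensors="pt",
).to(model.device)

# Assumed continuation of the test: enable the static cache, greedy-decode,
# and compare the decoded text with EXPECTED_GENERATION.
model.generation_config.cache_implementation = "static"
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
# self.assertListEqual(decoded, EXPECTED_GENERATION)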
