From 1f14324997ee904f07c8232136e5f3e6e023fdcf Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Mon, 19 Feb 2024 14:27:52 +0000
Subject: [PATCH] fix other cache tests

---
 src/transformers/models/llama/modeling_llama.py |  4 +++-
 tests/test_cache_utils.py                       | 10 +++++-----
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index eeac16f543a01d..c3bfcaecb62090 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -1038,6 +1038,7 @@ def _update_causal_mask(self, attention_mask, input_tensor):
 
         batch_size, seq_length = input_tensor.shape[:2]
         dtype = input_tensor.dtype
+        device = input_tensor.device
 
         # support going beyond cached `max_position_embedding`
         if seq_length > self.causal_mask.shape[-1]:
@@ -1053,8 +1054,9 @@ def _update_causal_mask(self, attention_mask, input_tensor):
             (self.config.max_position_embeddings, self.config.max_position_embeddings),
             fill_value=torch.finfo(dtype).min,
         )
-        causal_mask = torch.triu(mask, diagonal=1).to(dtype)
+        causal_mask = torch.triu(mask, diagonal=1)
+        causal_mask = causal_mask.to(dtype=dtype, device=device)
 
         if attention_mask is not None and attention_mask.dim() == 2:
             mask_length = attention_mask.shape[-1]
             padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
diff --git a/tests/test_cache_utils.py b/tests/test_cache_utils.py
index 5f3af2acf5723c..6d31d63e82ef51 100644
--- a/tests/test_cache_utils.py
+++ b/tests/test_cache_utils.py
@@ -293,7 +293,7 @@ def test_sink_cache_iterative_prompts(self):
     @parameterized.expand(["eager", "sdpa", "flash_attention_2"])
     def test_static_cache_greedy_sampling_pad_left(self, attn_implementation):
         EXPECTED_GENERATION = [
-            "The best color is the one that complements the subject you are photograph",
+            "The best color is the one that complements the skin tone of the",
             "We should not undermind the issues at hand.\nWe should not undermind the issues",
         ]
 
@@ -333,18 +333,18 @@ def test_static_cache_greedy_sampling_pad_left(self, attn_implementation):
     @parameterized.expand(["eager", "sdpa", "flash_attention_2"])
     def test_static_cache_greedy_sampling_pad_right(self, attn_implementation):
         EXPECTED_GENERATION = [
-            "The best color is\n\n\n\n\n\n\n\n\n\n",
-            "We should not undermind the issues at hand, but address them head on.\nI think",
+            "The best color isЋ the one that complements the skin tone of",
+            "We should not undermind the issues at hand.\nWe should not undermind the issues",
         ]
 
         tokenizer = AutoTokenizer.from_pretrained(
-            "NousResearch/Llama-2-7b-chat-hf", padding_side="left", pad_token="<pad>"
+            "NousResearch/Llama-2-7b-chat-hf", padding_side="right", pad_token="<pad>"
         )
         model = AutoModelForCausalLM.from_pretrained(
             "NousResearch/Llama-2-7b-chat-hf",
             torch_dtype=torch.bfloat16,
             attn_implementation=attn_implementation,
-        ).to("cuda:1")
+        ).to(torch_device)
         inputs = tokenizer(
             ["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt"
         ).to(model.device)