From 33b3cfb7845ad5796ac8e15fe7ceff489de00916 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Mon, 19 Feb 2024 13:48:13 +0000
Subject: [PATCH 1/3] batched llama

---
 src/transformers/models/llama/modeling_llama.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index c30be2a2da4f63..eeac16f543a01d 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -103,7 +103,10 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
 
     def forward(self, x, position_ids, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
-        freqs = (self.inv_freq[:, None].float().expand(-1, position_ids.shape[0]) @ (position_ids.float())).t()
+        freqs = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) @ (
+            position_ids[:, None, :].float()
+        )
+        freqs = freqs.transpose(1, 2)
         emb = torch.cat((freqs, freqs), dim=-1)
         return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype)
 
@@ -181,6 +184,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     Returns:
         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
     """
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed

From a64dc2de169f19d03047fc43f45e707ff8d8350c Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Mon, 19 Feb 2024 14:27:52 +0000
Subject: [PATCH 2/3] fix other cache tests

---
 src/transformers/models/llama/modeling_llama.py |  4 +++-
 tests/test_cache_utils.py                       | 10 +++++-----
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index eeac16f543a01d..c3bfcaecb62090 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -1038,6 +1038,7 @@ def _update_causal_mask(self, attention_mask, input_tensor):
 
         batch_size, seq_length = input_tensor.shape[:2]
         dtype = input_tensor.dtype
+        device = input_tensor.device
 
         # support going beyond cached `max_position_embedding`
         if seq_length > self.causal_mask.shape[-1]:
@@ -1053,8 +1054,9 @@
             (self.config.max_position_embeddings, self.config.max_position_embeddings),
             fill_value=torch.finfo(dtype).min,
         )
-        causal_mask = torch.triu(mask, diagonal=1).to(dtype)
+        causal_mask = torch.triu(mask, diagonal=1)
 
+        causal_mask = causal_mask.to(dtype=dtype, device=device)
         if attention_mask is not None and attention_mask.dim() == 2:
             mask_length = attention_mask.shape[-1]
             padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
diff --git a/tests/test_cache_utils.py b/tests/test_cache_utils.py
index 5f3af2acf5723c..6d31d63e82ef51 100644
--- a/tests/test_cache_utils.py
+++ b/tests/test_cache_utils.py
@@ -293,7 +293,7 @@ def test_sink_cache_iterative_prompts(self):
     @parameterized.expand(["eager", "sdpa", "flash_attention_2"])
     def test_static_cache_greedy_sampling_pad_left(self, attn_implementation):
         EXPECTED_GENERATION = [
-            "The best color is the one that complements the subject you are photograph",
+            "The best color is the one that complements the skin tone of the",
             "We should not undermind the issues at hand.\nWe should not undermind the issues",
         ]
 
@@ -333,18 +333,18 @@ def test_static_cache_greedy_sampling_pad_left(self, attn_implementation):
     @parameterized.expand(["eager", "sdpa", "flash_attention_2"])
     def test_static_cache_greedy_sampling_pad_right(self, attn_implementation):
         EXPECTED_GENERATION = [
-            "The best color is\n\n\n\n\n\n\n\n\n\n",
-            "We should not undermind the issues at hand, but address them head on.\nI think",
+            "The best color isЋ the one that complements the skin tone of",
+            "We should not undermind the issues at hand.\nWe should not undermind the issues",
        ]
 
        tokenizer = AutoTokenizer.from_pretrained(
-            "NousResearch/Llama-2-7b-chat-hf", padding_side="left", pad_token=""
+            "NousResearch/Llama-2-7b-chat-hf", padding_side="right", pad_token=""
        )
        model = AutoModelForCausalLM.from_pretrained(
            "NousResearch/Llama-2-7b-chat-hf",
            torch_dtype=torch.bfloat16,
            attn_implementation=attn_implementation,
-        ).to("cuda:1")
+        ).to(torch_device)
        inputs = tokenizer(
            ["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt"
        ).to(model.device)

From 9b2c3f7f2aea408b0ee215baff3bf043ebb353ca Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Tue, 20 Feb 2024 09:49:58 +0000
Subject: [PATCH 3/3] add bc

---
 .../models/llama/modeling_llama.py | 30 +++++++++++++++-----
 1 file changed, 25 insertions(+), 5 deletions(-)

diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index c3bfcaecb62090..9e2efe79d9b3b0 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -101,14 +101,34 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
         self.register_buffer("inv_freq", inv_freq, persistent=False)
 
+    @property
+    def sin_cached(self):
+        logger.warning_once(
+            "The sin_cached attribute will be removed in 4.40. Bear in mind that its contents changed in v4.38. Use "
+            "the forward method of RoPE from now on instead."
+        )
+        return self._sin_cached
+
+    @property
+    def cos_cached(self):
+        logger.warning_once(
+            "The cos_cached attribute will be removed in 4.40. Bear in mind that its contents changed in v4.38. Use "
+            "the forward method of RoPE from now on instead."
+        )
+        return self._cos_cached
+
     def forward(self, x, position_ids, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
-        freqs = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) @ (
-            position_ids[:, None, :].float()
-        )
-        freqs = freqs.transpose(1, 2)
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        position_ids_expanded = position_ids[:, None, :].float()
+        freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
         emb = torch.cat((freqs, freqs), dim=-1)
-        return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype)
+        cos = emb.cos().to(dtype=x.dtype)
+        sin = emb.sin().to(dtype=x.dtype)
+        # backwards compatibility
+        self._cos_cached = cos
+        self._sin_cached = sin
+        return cos, sin
 
 
 class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):