Llama: fix batched generation #29109
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@@ -293,7 +293,7 @@ def test_sink_cache_iterative_prompts(self):
     @parameterized.expand(["eager", "sdpa", "flash_attention_2"])
     def test_static_cache_greedy_sampling_pad_left(self, attn_implementation):
         EXPECTED_GENERATION = [
-            "The best color is the one that complements the subject you are photograph",
+            "The best color is the one that complements the skin tone of the",
These changed test results were checked against 4b236aed7618d90546cd2e8797dab5b4a24c5dce
(the commit before the static caches were introduced).
These tests do batched generation, hence the need to change them.
👉 the fact that this PR matches the commit before the static caches in this test means that we can now do left-padded batched generation with the same results!
I'll have to run the benchmark on the A100 to make sure everything is alright, but otherwise it should be good.
Great work, nice catch! I'll approve but let me run the benchmark on my side!
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
Let's unsqueeze in the rotary embedding instead, no? Or does that change the shape we previously had?
Same shapes/no shape problems, but unsqueezing here is preferred by some users (see #27117).
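For context, here is a minimal sketch of how those unsqueezed cos/sin are consumed (not the library source; the rotate_half helper and the toy shapes are written out only for illustration). cos/sin come in as [bsz, seq_len, head_dim] while q/k are [bsz, num_heads, seq_len, head_dim], so unsqueeze_dim=1 inserts the missing head axis for broadcasting, and callers with other layouts can pick a different dim (the motivation in #27117):

import torch

def rotate_half(x):
    # Rotate half the hidden dims of the input (same idea as the library helper).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    cos = cos.unsqueeze(unsqueeze_dim)  # [bsz, 1, seq_len, head_dim]
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

bsz, num_heads, seq_len, head_dim = 2, 4, 6, 8
q = torch.randn(bsz, num_heads, seq_len, head_dim)
k = torch.randn(bsz, num_heads, seq_len, head_dim)
cos = torch.randn(bsz, seq_len, head_dim)
sin = torch.randn(bsz, seq_len, head_dim)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
assert q_rot.shape == q.shape and k_rot.shape == k.shape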
freqs = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) @ (
    position_ids[:, None, :].float()
)
freqs = freqs.transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype)
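As a shape walkthrough of the snippet above (toy sizes, only the indexing pattern mirrors the hunk): inv_freq is [dim/2], inv_freq[None, :, None] expanded is [bsz, dim/2, 1], position_ids[:, None, :] is [bsz, 1, seq_len], so the batched matmul yields [bsz, dim/2, seq_len] and the transpose gives [bsz, seq_len, dim/2], i.e. one outer product per batch member instead of mixing batch rows:

import torch

# Toy dimensions for illustration only.
dim, bsz, seq_len = 8, 3, 5
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))  # [dim/2]
position_ids = torch.arange(seq_len).repeat(bsz, 1)                    # [bsz, seq_len]

# Batched computation, same indexing as the hunk above.
freqs = inv_freq[None, :, None].expand(bsz, -1, 1) @ position_ids[:, None, :].float()
freqs = freqs.transpose(1, 2)                                          # [bsz, seq_len, dim/2]

# Reference: an explicit outer product per batch member.
reference = torch.stack([torch.outer(position_ids[i].float(), inv_freq) for i in range(bsz)])
assert torch.allclose(freqs, reference)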
BTW, for BC we could/should still cache the RoPE, no?
With a _sin_cached property that does logger.warning_once("will be removed in 4.39"). WDYT?
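A rough sketch of that suggestion (the attribute names, warning text, and removal timing are assumptions, and Python's warnings module stands in for the library logger to keep the example self-contained):

import warnings

class RotaryEmbeddingWithBCShim:
    # Keep the previously cached cos/sin reachable for backward compatibility,
    # but warn whenever user code reads them.
    def __init__(self):
        self._cos_cached = None
        self._sin_cached = None

    @property
    def cos_cached(self):
        warnings.warn(
            "cos_cached is deprecated and will be removed in a future release; "
            "use the cos/sin values returned by forward() instead.",
            FutureWarning,
        )
        return self._cos_cached

    @property
    def sin_cached(self):
        warnings.warn(
            "sin_cached is deprecated and will be removed in a future release; "
            "use the cos/sin values returned by forward() instead.",
            FutureWarning,
        )
        return self._sin_cached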
causal_mask = torch.triu(mask, diagonal=1).to(dtype)
causal_mask = torch.triu(mask, diagonal=1)

causal_mask = causal_mask.to(dtype=dtype, device=device)
good catch!
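For reference, a free-standing sketch of the pattern in that hunk (not the library function itself): build the causal mask with the minimum fill value first, then cast and move it in a single step at the end.

import torch

def make_causal_mask(seq_len, dtype, device):
    # Fill with the most negative representable value, zero out the lower triangle,
    # and only then transfer to the target dtype/device in one go.
    min_value = torch.finfo(dtype).min
    mask = torch.full((seq_len, seq_len), fill_value=min_value)
    causal_mask = torch.triu(mask, diagonal=1)
    return causal_mask.to(dtype=dtype, device=device)

mask = make_causal_mask(4, torch.float16, torch.device("cpu"))
assert mask[0, 1] == torch.finfo(torch.float16).min and mask[1, 0] == 0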
@@ -333,18 +333,18 @@ def test_static_cache_greedy_sampling_pad_left(self, attn_implementation):
     @parameterized.expand(["eager", "sdpa", "flash_attention_2"])
     def test_static_cache_greedy_sampling_pad_right(self, attn_implementation):
         EXPECTED_GENERATION = [
-            "The best color is\n\n\n\n\n\n\n\n\n\n",
-            "We should not undermind the issues at hand, but address them head on.\nI think",
+            "The best color isЋ the one that complements the skin tone of",
-isЋ t
+is t
seems strange 😅 but alright
hehe this weird one is a copy/paste
(it has right-padding, so we should expect weird things at generation time)
Alright, no significant slowdowns so 🟢, but I can't do naive dynamic generation with the same script as before:
File "/home/arthur/transformers/../static-kv-cache/clean_bench.py", line 147, in <module>
outputs = model(input_ids, past_key_values=past_key_values,position_ids=position_ids,cache_position=cache_position, return_dict=False, use_cache = True)
File "/home/arthur/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/arthur/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1545, in _call_impl
return forward_call(*args, **kwargs)
File "/home/arthur/transformers/src/transformers/models/llama/modeling_llama.py", line 1155, in forward
outputs = self.model(
File "/home/arthur/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/arthur/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1545, in _call_impl
return forward_call(*args, **kwargs)
File "/home/arthur/transformers/src/transformers/models/llama/modeling_llama.py", line 995, in forward
layer_outputs = decoder_layer(
File "/home/arthur/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/arthur/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1545, in _call_impl
return forward_call(*args, **kwargs)
File "/home/arthur/transformers/src/transformers/models/llama/modeling_llama.py", line 721, in forward
hidden_states, self_attn_weights, present_key_value = self.self_attn(
File "/home/arthur/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/arthur/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1545, in _call_impl
return forward_call(*args, **kwargs)
File "/home/arthur/transformers/src/transformers/models/llama/modeling_llama.py", line 628, in forward
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
File "/home/arthur/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/arthur/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1545, in _call_impl
return forward_call(*args, **kwargs)
File "/home/arthur/transformers/src/transformers/models/llama/modeling_llama.py", line 107, in forward
position_ids[:, None, :].float()
IndexError: too many indices for tensor of dimension 1
@ArthurZucker regarding the benchmark error: position ids should be a 2D tensor, just like the input ids :D I also had to adapt it on my end.
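In case it helps anyone adapting their own script, one way to build those 2D position_ids from the attention mask (illustrative only; the pad-token id and the clamp handling of padded slots are assumptions, and generate() does this internally):

import torch

# Left-padded toy batch; 0 stands in for the pad token id here.
input_ids = torch.tensor([[0, 0, 11, 12, 13],
                          [21, 22, 23, 24, 25]])
attention_mask = (input_ids != 0).long()

position_ids = attention_mask.cumsum(dim=-1) - 1
position_ids = position_ids.clamp(min=0)        # padded slots just get position 0
assert position_ids.shape == input_ids.shape    # [bsz, seq_len], same as input_ids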
Alright, if passing a 1D tensor before was already erroring out!
@gante thanks a lot for this
self._cos_cached = cos
self._sin_cached = sin
We should not always overwrite them. We need them accessible, but not overwritten on every forward pass.
What does this PR do?
Fixes batched inference on Llama after the static cache changes were added. For instance,
RUN_SLOW=1 py.test tests/test_cache_utils.py::CacheIntegrationTest::test_dynamic_cache_beam_search
now passes.

What was wrong?
position_ids has shape [bsz, seq_len]. The line computing freqs was correct for batch size 1, but incorrect for larger batch sizes: it was summing the values for the different batch members. Therefore, we need to create another dimension to prevent this sum from happening, which is what this PR does.
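As a quick end-to-end illustration (hedged sketch: the checkpoint and prompts below are only examples, echoing the ones used in the tests, and greedy settings are assumed), left-padded batched generation can be checked like this:

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # small Llama-architecture checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_id)

prompts = ["The best color is", "We should not undermind the issues at hand"]
inputs = tokenizer(prompts, return_tensors="pt", padding=True)

# Greedy decoding; with this fix, left-padded batched generation matches the
# single-sequence results again.
outputs = model.generate(**inputs, do_sample=False, max_new_tokens=10)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))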
Throughput impact of changes
None 🙌 [Measured on my end, RTX3090 + TinyLlama/TinyLlama-1.1B-Chat-v1.0]
Before this PR
After this PR