-
Notifications
You must be signed in to change notification settings - Fork 27.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Self-speculation (Layer-Skip Llama) #34240
Conversation
status: code runs, output is gibberish. Numerical debugging after lunch to figure out what's wrong |
EDIT: Please ignore my comment above. I thought the output of early exit was gibberish but I think it was the output of self-speculative decoding was gibberish. Yes, self-speculative decoding should have same quality as last layer. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi. My name is Mostafa and I am one of the main authors of the LayerSkip paper!
Thanks for working on this PR so quickly! I have provided some comments.
Also, for the future, I have some suggestions to consider:
early_exit
arg in generation could be come a callable function for researchers to experiment with dynamic early exit, i.e., a different condition or heurestic to exit for each token (e.g., cosing similarity between a layers input and output above a certain threshold). This is done in papers like CALM.- adapter modules for early exit. Rather than just exiting by jumping to the model's LM head, users may opt to add their own separate LM head or even add their own adapter layers when exiting. This is done in a paper like Kangaroo.
- Different types of self-speculative decoding, e.g.,
- Draft stage uses a subset of KV cache. This is done in MagicDec.
I am happy to discuss online or offline how we can add more features along this direction to enable researchers to unlock a lot of early exit ideas.
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) | ||
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here torch.cat
will only be correct if min(new_positions) == previous_length + 1
? If that's correct, should we also add an assert
statement for that?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, that is correct!
I'm not going to add any check for now, though, and rely on internal tests to detect issues: adding a check here would hurt throughput in the forward pass, and a test can immediately detect issues :)
generation_config: "GenerationConfig", | ||
model_kwargs: Dict, | ||
inputs_tensor: Optional[torch.Tensor] = None, | ||
logits_processor: "LogitsProcessorList" = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FYI, I recently also added stopping_criteria
as well to support integration with Eleuther LM Eval Harness:
facebookresearch/LayerSkip@e38784d
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you can ignore my comment about supporting StoppingCriteria
. I checked out the PR and integrated with LM Eval Harness and found out that we don't need it.
I think I needed it in my custom implementation, but the native HF implementation doesn't.
@@ -887,7 +887,7 @@ def forward( | |||
all_self_attns = () if output_attentions else None | |||
next_decoder_cache = None | |||
|
|||
for decoder_layer in self.layers: | |||
for decoder_layer in self.layers[: self.num_hidden_layers]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Smart! I like that simple change that enables flexibility.
early_exit_outputs = model.generate(**inputs, early_exit=4, do_sample=False, max_new_tokens=20) | ||
early_exit_decoded = tokenizer.batch_decode(early_exit_outputs, skip_special_tokens=True) | ||
self.assertEqual(early_exit_decoded, [expected_output]) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suggest adding an assertion check to ensure the output of a model prunted to early_exit
layers has the identical output as the same model with early_exit
arg in generation
# Remove layers manually | |
model = model.model.layers[:4] | |
del model.model.layers[4:] | |
model.num_hidden_layers = 4 | |
manual_early_exit_outputs = model.generate(**inputs, do_sample=False, max_new_tokens=20) | |
manual_early_exit_decoded = tokenizer.batch_decode(manual_early_exit_outputs, skip_special_tokens=True) | |
self.assertEqual(early_exit_decoded, manual_early_exit_decoded) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I might have misunderstood the code, does model.generate(**inputs, early_exit=4, do_sample=False, max_new_tokens=20)
perform static early exit, or does it perform self-speculative early-exit decoding?
Personally, I would suggest to separate them some how:
- Static early exit:
model.generate(**inputs, early_exit=4)
- Self-speculative decoding, early exit:
model.generate(**inputs, assisstant_model={"early_exit": 4})
or something like that
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The interface is indeed confusing -- the demo above was meant to run self-speculative early-exit decoding.
I see two options:
model.generate(**inputs, assistant_early_exit=4)
-- make the name of the argument more precisemodel.generate(**inputs, assistant_model=model, early_exit=4)
-- withassistant_model
set we know we are doing speculative decoding, so the use ofearly_exit
becomes more self-evident.
I was thinking of going with option 2, since we could then do model.generate(**inputs, early_exit=4)
to run static early exit. WDYT?
(btw, in the long run, we will mode ALL assisted generation/speculative decoding args into a assistant_kwags
dictionary, otherwise things will get messy soon)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am worried that option 2. might be confusing because assistant_model=model, early_exit=4
could potentially be the main model exiting at layer 4 and draft model be the main model.
On the short-term, I feel option 1 is better as assistant_early_exit=4
clearly means we exit early for speculative decoding.
On the long-term, assistant_kwags
is a great idea, or even having something likeassisstant_model=SelfSpeculative(early_exit=4)
Hi Mostafa (@mostafaelhoushi) 👋 Glad to see you here! My utmost goal for this PR is to get Layer Skip to the hands of our users with a) good throughput numbers b) a simple interface. Self-speculative decoding is indeed the best of both worlds for low batch sizes 💪 I appreciate the extra suggestions, but they add significant complexity -- e.g. if we accept callable for self-speculative decoding, we might want to apply the callable in different positions. Keeping things somewhat simple means others can just fork what we have and implement their idea quickly on top! It also makes our maintenance job doable 🤗 Naturally, if a given technique shows clear advantages and can be applied on pre-trained weights without too much complexity, like layer skip, we'll jump straight to implementation. (For instance, a few years ago we implemented a complex constrained decoding method, before json generation became popular. However, because the implementation was complex and it was somewhat niche, it quickly became unmaintained -- we got the additional code bloat with no relevant benefits) Sorry to be a turn off -- I really appreciate the ideas coming in! |
Update: I've removed cache sharing between the early exit and the full model -- it would require significant changes in Tomorrow I'll work on updating the interface as per comments above, and update tests. |
Thanks @gante ! Everything you said makes sense and I agree that this is the wise way forward.
OK. Sounds good. According to our experiments, cache sharing leads to an additional 10% speedup, but without cache sharing we should still get significant speedup. |
Hi everyone. In my opinion, the PR is ready to go, except for one feedback. My only feedback is to modify the name of the argument from
Thanks! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have added my final suggested edits to rename early_exit
argument to assistant_early_exit
Co-authored-by: Mostafa Elhoushi <[email protected]>
@@ -416,16 +416,6 @@ Assisted decoding assumes the main and assistant models have the same tokenizer, | |||
Currently, only greedy search and sampling are supported with assisted decoding, and assisted decoding doesn't support batched inputs. | |||
To learn more about assisted decoding, check [this blog post](https://huggingface.co/blog/assisted-generation). | |||
|
|||
#### Universal Assisted Decoding |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(nesting was not fully right -- normal "speculative decoding" examples were under "Universal Assisted Decoding". Moved a few things around)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(double-checked the changes, most of them done by me, LGTM :D)
@ArthurZucker @mostafaelhoushi have a look at the state of the PR and, if we all agree, let's merge 🤗 |
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very nice! 🔥
""" | ||
`CandidateGenerator` class to be used for assisted generation and speculative decoding. This class generates | ||
candidates through the use of **the model itself**, exiting early. Can only be used with models that support early | ||
exit. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe mention a specific model as an example here? (there aren't many models that currently support it)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Adding an example of a model that supports early exit as per the suggestion of @pcuenca . Not sure if it is a good idea to add link to a model collection in the docstring but feel free to remove it.
exit. | |
exit, e.g., `facebook/layerskip-llama3.2-1B` or any of the models listed in this [collection](https://huggingface.co/collections/facebook/layerskip-666b25c50c8ae90e1965727a). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a single model would be enough for me, a collection could give the impression that we are maintaining a list of compatible models there, which is not the case.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(added a single model :) )
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @gante ! All my feedback is addressed! Approving the PR from my side.
Co-authored-by: Pedro Cuenca <[email protected]> Co-authored-by: Mostafa Elhoushi <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's a smart trick! 🤗
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't mind! It's annoying to have to monkey patch all models, but fine in this case as it is strictly equivalent.
I think we have access to the config at this point because of
super().__init__(config)
, so would rather we use config.num_hidden_layers
directly!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done -- we now use self.config.num_hidden_layers
, as opposed to using a new attribute (self.num_hidden_layers
) ✅
* 😅 * early exit (huggingface#34244) * mvp * docs and tests * a few fixes * no shared cache * Apply suggestions from code review Co-authored-by: Mostafa Elhoushi <[email protected]> * docs * make fix-copies * cohere fix * [test all] * [test all] consistent model code copies * [test all] make fix-copies :D * Apply suggestions from code review Co-authored-by: Pedro Cuenca <[email protected]> Co-authored-by: Mostafa Elhoushi <[email protected]> * Update src/transformers/generation/candidate_generator.py * Update src/transformers/generation/configuration_utils.py Co-authored-by: Pedro Cuenca <[email protected]> * [test all] don't use a stand-alone attribute; fix test --------- Co-authored-by: Joao Gante <[email protected]> Co-authored-by: Joao Gante <[email protected]> Co-authored-by: Mostafa Elhoushi <[email protected]> Co-authored-by: Pedro Cuenca <[email protected]>
* 😅 * early exit (huggingface#34244) * mvp * docs and tests * a few fixes * no shared cache * Apply suggestions from code review Co-authored-by: Mostafa Elhoushi <[email protected]> * docs * make fix-copies * cohere fix * [test all] * [test all] consistent model code copies * [test all] make fix-copies :D * Apply suggestions from code review Co-authored-by: Pedro Cuenca <[email protected]> Co-authored-by: Mostafa Elhoushi <[email protected]> * Update src/transformers/generation/candidate_generator.py * Update src/transformers/generation/configuration_utils.py Co-authored-by: Pedro Cuenca <[email protected]> * [test all] don't use a stand-alone attribute; fix test --------- Co-authored-by: Joao Gante <[email protected]> Co-authored-by: Joao Gante <[email protected]> Co-authored-by: Mostafa Elhoushi <[email protected]> Co-authored-by: Pedro Cuenca <[email protected]>
What does this PR do?
Adds self-speculation -- support for Meta Llama 3.2 Layer-Skip model
Test script: