Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Self-speculation (Layer-Skip Llama) #34240

Merged
merged 16 commits into from
Nov 19, 2024
Merged

Self-speculation (Layer-Skip Llama) #34240

merged 16 commits into from
Nov 19, 2024

Conversation

ArthurZucker
Copy link
Collaborator

@ArthurZucker ArthurZucker commented Oct 18, 2024

What does this PR do?

Adds self-speculation -- support for Meta Llama 3.2 Layer-Skip model


Test script:

from transformers import AutoTokenizer, AutoModelForCausalLM
import time

expected_output = [""]

prompt = "Alice and Bob"
checkpoint = "facebook/layerskip-llama3.2-1B"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

model = AutoModelForCausalLM.from_pretrained(checkpoint).to("cuda")
original_outputs = model.generate(**inputs, do_sample=False, max_new_tokens=20)  #warmup

start = time.time()
original_outputs = model.generate(**inputs, do_sample=False, max_new_tokens=20)
end = time.time()
print(f"Original: {end-start}")
print(f"Output text", tokenizer.batch_decode(original_outputs, skip_special_tokens=True))

start = time.time()
early_exit_outputs = model.generate(**inputs, assistant_early_exit=4, do_sample=False, max_new_tokens=20)
end = time.time()
print(f"Early Exit: {end-start}")
print(f"Early Exit text", tokenizer.batch_decode(early_exit_outputs, skip_special_tokens=True))

ArthurZucker and others added 2 commits October 18, 2024 11:31
* mvp

* docs and tests
@gante
Copy link
Member

gante commented Oct 21, 2024

status: code runs, output is gibberish. Numerical debugging after lunch to figure out what's wrong

@ArthurZucker ArthurZucker changed the title Draft Draft layer skip addition Oct 21, 2024
@mostafaelhoushi
Copy link
Contributor

mostafaelhoushi commented Oct 22, 2024

status: code runs, output is gibberish. Numerical debugging after lunch to figure out what's wrong

Can you show a sample of how the output looks like? It is expected to be of lower quality but interested to see how gibberish it would be.

Also, maybe try a later layer like layer 14 or 13 to see if it's still gibberish?

EDIT: Please ignore my comment above. I thought the output of early exit was gibberish but I think it was the output of self-speculative decoding was gibberish. Yes, self-speculative decoding should have same quality as last layer.

Copy link
Contributor

@mostafaelhoushi mostafaelhoushi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi. My name is Mostafa and I am one of the main authors of the LayerSkip paper!
Thanks for working on this PR so quickly! I have provided some comments.

Also, for the future, I have some suggestions to consider:

  • early_exit arg in generation could be come a callable function for researchers to experiment with dynamic early exit, i.e., a different condition or heurestic to exit for each token (e.g., cosing similarity between a layers input and output above a certain threshold). This is done in papers like CALM.
  • adapter modules for early exit. Rather than just exiting by jumping to the model's LM head, users may opt to add their own separate LM head or even add their own adapter layers when exiting. This is done in a paper like Kangaroo.
  • Different types of self-speculative decoding, e.g.,
    • Draft stage uses a subset of KV cache. This is done in MagicDec.

I am happy to discuss online or offline how we can add more features along this direction to enable researchers to unlock a lot of early exit ideas.

docs/source/en/generation_strategies.md Outdated Show resolved Hide resolved
Comment on lines +453 to +454
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here torch.cat will only be correct if min(new_positions) == previous_length + 1? If that's correct, should we also add an assert statement for that?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, that is correct!

I'm not going to add any check for now, though, and rely on internal tests to detect issues: adding a check here would hurt throughput in the forward pass, and a test can immediately detect issues :)

generation_config: "GenerationConfig",
model_kwargs: Dict,
inputs_tensor: Optional[torch.Tensor] = None,
logits_processor: "LogitsProcessorList" = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI, I recently also added stopping_criteria as well to support integration with Eleuther LM Eval Harness:
facebookresearch/LayerSkip@e38784d

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can ignore my comment about supporting StoppingCriteria. I checked out the PR and integrated with LM Eval Harness and found out that we don't need it.
I think I needed it in my custom implementation, but the native HF implementation doesn't.

@@ -887,7 +887,7 @@ def forward(
all_self_attns = () if output_attentions else None
next_decoder_cache = None

for decoder_layer in self.layers:
for decoder_layer in self.layers[: self.num_hidden_layers]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Smart! I like that simple change that enables flexibility.

early_exit_outputs = model.generate(**inputs, early_exit=4, do_sample=False, max_new_tokens=20)
early_exit_decoded = tokenizer.batch_decode(early_exit_outputs, skip_special_tokens=True)
self.assertEqual(early_exit_decoded, [expected_output])

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest adding an assertion check to ensure the output of a model prunted to early_exit layers has the identical output as the same model with early_exit arg in generation

Suggested change
# Remove layers manually
model = model.model.layers[:4]
del model.model.layers[4:]
model.num_hidden_layers = 4
manual_early_exit_outputs = model.generate(**inputs, do_sample=False, max_new_tokens=20)
manual_early_exit_decoded = tokenizer.batch_decode(manual_early_exit_outputs, skip_special_tokens=True)
self.assertEqual(early_exit_decoded, manual_early_exit_decoded)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I might have misunderstood the code, does model.generate(**inputs, early_exit=4, do_sample=False, max_new_tokens=20) perform static early exit, or does it perform self-speculative early-exit decoding?

Personally, I would suggest to separate them some how:

  • Static early exit: model.generate(**inputs, early_exit=4)
  • Self-speculative decoding, early exit: model.generate(**inputs, assisstant_model={"early_exit": 4}) or something like that

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The interface is indeed confusing -- the demo above was meant to run self-speculative early-exit decoding.

I see two options:

  1. model.generate(**inputs, assistant_early_exit=4) -- make the name of the argument more precise
  2. model.generate(**inputs, assistant_model=model, early_exit=4) -- with assistant_model set we know we are doing speculative decoding, so the use of early_exit becomes more self-evident.

I was thinking of going with option 2, since we could then do model.generate(**inputs, early_exit=4) to run static early exit. WDYT?

(btw, in the long run, we will mode ALL assisted generation/speculative decoding args into a assistant_kwags dictionary, otherwise things will get messy soon)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am worried that option 2. might be confusing because assistant_model=model, early_exit=4 could potentially be the main model exiting at layer 4 and draft model be the main model.

On the short-term, I feel option 1 is better as assistant_early_exit=4 clearly means we exit early for speculative decoding.

On the long-term, assistant_kwags is a great idea, or even having something likeassisstant_model=SelfSpeculative(early_exit=4)

@gante
Copy link
Member

gante commented Oct 22, 2024

Hi Mostafa (@mostafaelhoushi) 👋 Glad to see you here!

My utmost goal for this PR is to get Layer Skip to the hands of our users with a) good throughput numbers b) a simple interface. Self-speculative decoding is indeed the best of both worlds for low batch sizes 💪

I appreciate the extra suggestions, but they add significant complexity -- e.g. if we accept callable for self-speculative decoding, we might want to apply the callable in different positions. Keeping things somewhat simple means others can just fork what we have and implement their idea quickly on top! It also makes our maintenance job doable 🤗 Naturally, if a given technique shows clear advantages and can be applied on pre-trained weights without too much complexity, like layer skip, we'll jump straight to implementation.

(For instance, a few years ago we implemented a complex constrained decoding method, before json generation became popular. However, because the implementation was complex and it was somewhat niche, it quickly became unmaintained -- we got the additional code bloat with no relevant benefits)

Sorry to be a turn off -- I really appreciate the ideas coming in!

@gante
Copy link
Member

gante commented Oct 22, 2024

Update: I've removed cache sharing between the early exit and the full model -- it would require significant changes in forward to skip the computation of the cached values, which we don't want to commit for now. The outputs are now correct 👍

Tomorrow I'll work on updating the interface as per comments above, and update tests.

@mostafaelhoushi
Copy link
Contributor

Hi Mostafa (@mostafaelhoushi) 👋 Glad to see you here!

My utmost goal for this PR is to get Layer Skip to the hands of our users with a) good throughput numbers b) a simple interface. Self-speculative decoding is indeed the best of both worlds for low batch sizes 💪

I appreciate the extra suggestions, but they add significant complexity -- e.g. if we accept callable for self-speculative decoding, we might want to apply the callable in different positions. Keeping things somewhat simple means others can just fork what we have and implement their idea quickly on top! It also makes our maintenance job doable 🤗 Naturally, if a given technique shows clear advantages and can be applied on pre-trained weights without too much complexity, like layer skip, we'll jump straight to implementation.

(For instance, a few years ago we implemented a complex constrained decoding method, before json generation became popular. However, because the implementation was complex and it was somewhat niche, it quickly became unmaintained -- we got the additional code bloat with no relevant benefits)

Sorry to be a turn off -- I really appreciate the ideas coming in!

Thanks @gante ! Everything you said makes sense and I agree that this is the wise way forward.

Update: I've removed cache sharing between the early exit and the full model -- it would require significant changes in forward to skip the computation of the cached values, which we don't want to commit for now. The outputs are now correct 👍

Tomorrow I'll work on updating the interface as per comments above, and update tests.

OK. Sounds good. According to our experiments, cache sharing leads to an additional 10% speedup, but without cache sharing we should still get significant speedup.

@mostafaelhoushi
Copy link
Contributor

Hi everyone. In my opinion, the PR is ready to go, except for one feedback. My only feedback is to modify the name of the argument from early_exit to assistant_early_exit for a couple of reasons:

  • to be consistent with the assistant_model argument for speculative decoding. Hence assistant_model is when the draft stage corresponds to a different model, and assistant_early_exit is when the draft stage is an exit at an earlier layer.
  • to avoid confusion with the idea of early exit of autoregressive decoding that the feature doesn't implement.

Thanks!

Copy link
Contributor

@mostafaelhoushi mostafaelhoushi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have added my final suggested edits to rename early_exit argument to assistant_early_exit

docs/source/en/generation_strategies.md Outdated Show resolved Hide resolved
docs/source/en/generation_strategies.md Outdated Show resolved Hide resolved
src/transformers/generation/candidate_generator.py Outdated Show resolved Hide resolved
src/transformers/generation/candidate_generator.py Outdated Show resolved Hide resolved
src/transformers/generation/configuration_utils.py Outdated Show resolved Hide resolved
src/transformers/generation/configuration_utils.py Outdated Show resolved Hide resolved
src/transformers/generation/configuration_utils.py Outdated Show resolved Hide resolved
src/transformers/generation/utils.py Outdated Show resolved Hide resolved
tests/generation/test_utils.py Outdated Show resolved Hide resolved
@gante gante marked this pull request as ready for review November 18, 2024 17:22
@@ -416,16 +416,6 @@ Assisted decoding assumes the main and assistant models have the same tokenizer,
Currently, only greedy search and sampling are supported with assisted decoding, and assisted decoding doesn't support batched inputs.
To learn more about assisted decoding, check [this blog post](https://huggingface.co/blog/assisted-generation).

#### Universal Assisted Decoding
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nesting was not fully right -- normal "speculative decoding" examples were under "Universal Assisted Decoding". Moved a few things around)

Copy link
Member

@gante gante left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(double-checked the changes, most of them done by me, LGTM :D)

@gante
Copy link
Member

gante commented Nov 18, 2024

@ArthurZucker @mostafaelhoushi have a look at the state of the PR and, if we all agree, let's merge 🤗

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Member

@pcuenca pcuenca left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice! 🔥

docs/source/en/generation_strategies.md Outdated Show resolved Hide resolved
docs/source/en/generation_strategies.md Outdated Show resolved Hide resolved
docs/source/en/generation_strategies.md Outdated Show resolved Hide resolved
docs/source/en/generation_strategies.md Outdated Show resolved Hide resolved
docs/source/en/generation_strategies.md Outdated Show resolved Hide resolved
"""
`CandidateGenerator` class to be used for assisted generation and speculative decoding. This class generates
candidates through the use of **the model itself**, exiting early. Can only be used with models that support early
exit.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe mention a specific model as an example here? (there aren't many models that currently support it)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding an example of a model that supports early exit as per the suggestion of @pcuenca . Not sure if it is a good idea to add link to a model collection in the docstring but feel free to remove it.

Suggested change
exit.
exit, e.g., `facebook/layerskip-llama3.2-1B` or any of the models listed in this [collection](https://huggingface.co/collections/facebook/layerskip-666b25c50c8ae90e1965727a).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a single model would be enough for me, a collection could give the impression that we are maintaining a list of compatible models there, which is not the case.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(added a single model :) )

src/transformers/generation/configuration_utils.py Outdated Show resolved Hide resolved
Copy link
Contributor

@mostafaelhoushi mostafaelhoushi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @gante ! All my feedback is addressed! Approving the PR from my side.

Co-authored-by: Pedro Cuenca <[email protected]>
Co-authored-by: Mostafa Elhoushi <[email protected]>
@gante gante changed the title Draft layer skip addition Self-speculation (Layer-Skip Llama) Nov 19, 2024
Copy link
Collaborator Author

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a smart trick! 🤗

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't mind! It's annoying to have to monkey patch all models, but fine in this case as it is strictly equivalent.
I think we have access to the config at this point because of
super().__init__(config), so would rather we use config.num_hidden_layers directly!

Copy link
Member

@gante gante Nov 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done -- we now use self.config.num_hidden_layers, as opposed to using a new attribute (self.num_hidden_layers) ✅

@gante gante merged commit 54739a3 into main Nov 19, 2024
27 checks passed
@gante gante deleted the layer-skip branch November 19, 2024 12:20
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
* 😅

* early exit (huggingface#34244)

* mvp

* docs and tests

* a few fixes

* no shared cache

* Apply suggestions from code review

Co-authored-by: Mostafa Elhoushi <[email protected]>

* docs

* make fix-copies

* cohere fix

* [test all]

* [test all] consistent model code copies

* [test all] make fix-copies :D

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <[email protected]>
Co-authored-by: Mostafa Elhoushi <[email protected]>

* Update src/transformers/generation/candidate_generator.py

* Update src/transformers/generation/configuration_utils.py

Co-authored-by: Pedro Cuenca <[email protected]>

* [test all] don't use a stand-alone attribute; fix test

---------

Co-authored-by: Joao Gante <[email protected]>
Co-authored-by: Joao Gante <[email protected]>
Co-authored-by: Mostafa Elhoushi <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
BernardZach pushed a commit to innovationcore/transformers that referenced this pull request Dec 6, 2024
* 😅

* early exit (huggingface#34244)

* mvp

* docs and tests

* a few fixes

* no shared cache

* Apply suggestions from code review

Co-authored-by: Mostafa Elhoushi <[email protected]>

* docs

* make fix-copies

* cohere fix

* [test all]

* [test all] consistent model code copies

* [test all] make fix-copies :D

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <[email protected]>
Co-authored-by: Mostafa Elhoushi <[email protected]>

* Update src/transformers/generation/candidate_generator.py

* Update src/transformers/generation/configuration_utils.py

Co-authored-by: Pedro Cuenca <[email protected]>

* [test all] don't use a stand-alone attribute; fix test

---------

Co-authored-by: Joao Gante <[email protected]>
Co-authored-by: Joao Gante <[email protected]>
Co-authored-by: Mostafa Elhoushi <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants