Automatic compilation in generate: do not rely on inner function #34923
Conversation
Thanks, before merging can you make sure:
- this fixed the PEFT issue! (we are missing a fast test let's add it!)
- this still has the expected performance gains!
🤗 LGTM otherwise!
Raised this point in issue #34906 as well: how can we pass custom arguments to compile in this case, say a different backend or other parameters?
Yeah we can use …
Let's rather go with generation_config / generate kwargs that can be passed and used!
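For illustration, here is roughly how that could look from the user side. This is only a sketch based on the `CompileConfig` usage shown further down in this thread; the `backend` and `mode` fields are assumed to mirror `torch.compile`'s keyword arguments on this branch:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, CompileConfig

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B", torch_dtype=torch.float16
).to("cuda")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
inputs = tokenizer("Hey what's the plan", return_tensors="pt").to("cuda")

# Custom torch.compile arguments are carried by a CompileConfig passed to generate(),
# instead of being hard-coded in an inner compiled function.
out = model.generate(
    **inputs,
    do_sample=False,
    max_new_tokens=100,
    cache_implementation="static",
    compile_config=CompileConfig(backend="inductor", mode="reduce-overhead"),
)
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
```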
Hey @BenjaminBossan, could you check if this PR solves the issue in `peft`, or point me to the failing tests please?
I checked and this branch indeed resolves the failing tests, thanks!
```python
# Only reset it if different from previous config
if getattr(self, "_last_compile_config", CompileConfig()) != compile_config:
    self._last_compile_config = compile_config
    self._compiled_call = torch.compile(self.__call__, **compile_config.to_dict())
```
I am wondering if it would make more sense to do `self._compiled_call = None` at this line, and delegate the actual compilation to `compiled_call` only when it is called (like you do in `generate`).
Having 2 places that call `torch.compile` seems slightly strange to me (but it's OK though).
And if you decide to take my suggestion, the name `_set_compile_call` should probably be changed to `_set_compile_config`.
`torch.compile` does not compile anything by itself; calling the compiled function is what triggers compilation.
I know. What I mean here is not the underlying compilation, but rather that `torch.compile` is called twice. Two different methods calling `torch.compile` seems like not very good style to me (for me, it's better if one method only sets the config and another one takes the job of calling `torch.compile`).
It's no big deal though, just a habit of making each method do its own job.
You're right, but maybe then call `_set_compile_call` in `compiled_call` instead? That way, `_set_compile_call` is the only place using `torch.compile`, and `compiled_call` is still only used for accessing the function (with the drawback of a call to `_set_compile_call` in case the compiled function does not already exist).
Even with that, in `generate` we still need to call `_set_compile_call` and then `compiled_call`, right? (As `compiled_call` doesn't take the argument.) If that is the case, then it's odd to have `compiled_call` calling `_set_compile_call`.
Otherwise, if you think it's OK to change `compiled_call` to accept a `compile_config` argument, then your suggested change sounds good to me.
Oh, `compiled_call` is a property, so it cannot take an argument... so the concern in the first paragraph of my comment above still stands.
Anyway, these are implementation details and don't affect users. Don't take it too seriously if the changes would take too much time.
Yes, I thought it made sense to make it a property so that `model.compiled_call(**inputs)` could always be used directly as an alternative to `model.forward(**inputs)`.
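For reference, here is a minimal sketch of the split discussed above, where one method only stores the config and the property lazily calls `torch.compile`. The method and attribute names follow the discussion but are illustrative, not the final implementation:

```python
import torch
from transformers import CompileConfig


class LazyCompileMixin:
    """Illustrative mixin: the compile config is stored separately from the actual compilation."""

    def _set_compile_config(self, compile_config: CompileConfig):
        # Only invalidate the cached compiled forward if the config actually changed.
        if getattr(self, "_last_compile_config", None) != compile_config:
            self._last_compile_config = compile_config
            self._compiled_call = None

    @property
    def compiled_call(self):
        # Lazily wrap __call__ with torch.compile on first access after a config change.
        if getattr(self, "_compiled_call", None) is None:
            config = getattr(self, "_last_compile_config", CompileConfig())
            self._compiled_call = torch.compile(self.__call__, **config.to_dict())
        return self._compiled_call
```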
Another comment:
Thanks for looking into this as well @ydshieh! Regarding your question, the following script actually does not re-compile for the last 2 calls:

```python
import time
import warnings

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, CompileConfig

warnings.filterwarnings("ignore")
device = 3

model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B", torch_dtype=torch.float16).to(device)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
sequence = "Hey what's the plan"
inputs = tokenizer.encode(sequence, return_tensors='pt').to(device)

model.generation_config.temperature = 1.0
model.generation_config.top_p = 1.0

# Compile default config
t0 = time.time()
out = model.generate(inputs, do_sample=False, max_new_tokens=500, cache_implementation="static")
out = tokenizer.batch_decode(out, skip_special_tokens=True)[0]
dt = time.time() - t0
print(f'dt: {dt}')
print(model._last_compile_config.to_dict())

# Compile new config
t0 = time.time()
out = model.generate(inputs, do_sample=False, max_new_tokens=500, cache_implementation="static", compile_config=CompileConfig(dynamic=True))
out = tokenizer.batch_decode(out, skip_special_tokens=True)[0]
dt = time.time() - t0
print(f'dt: {dt}')
print(model._last_compile_config.to_dict())

# Back to 1st config
t0 = time.time()
out = model.generate(inputs, do_sample=False, max_new_tokens=500, cache_implementation="static")
out = tokenizer.batch_decode(out, skip_special_tokens=True)[0]
dt = time.time() - t0
print(f'dt: {dt}')
print(model._last_compile_config.to_dict())

# Back to 2nd config
t0 = time.time()
out = model.generate(inputs, do_sample=False, max_new_tokens=500, cache_implementation="static", compile_config=CompileConfig(dynamic=True))
out = tokenizer.batch_decode(out, skip_special_tokens=True)[0]
dt = time.time() - t0
print(f'dt: {dt}')
print(model._last_compile_config.to_dict())
```
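If someone wants to double-check that no recompilation is triggered for the last two calls, one option (a suggestion on my side, assuming a recent PyTorch 2.x where this logging API is available) is to enable Dynamo's recompile logging before running the script:

```python
import torch

# Print a log line every time torch.compile has to recompile a graph; with graph
# reuse, the last two generate() calls above should not produce any such lines.
torch._logging.set_logs(recompiles=True)
```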
LGTM, could you confirm the performance boost with a gist script shared here (like the one I shared), just to double-check? 🤗
Confirming that this script returns

```
>>> Without static cache and compile: 27.079 s
>>> Using `torch.compile`.
>>> Compiling default config: 26.816 s
>>> Using compiled graph: 6.541 s
>>> Compiling new config: 24.327 s
>>> Using compiled new graph: 6.930 s
>>> Back to 1st config and graph: 6.528 s
```

on a machine with an A100 GPU, which is the expected result.
…gingface#34923)
* compiled forward in PreTrainedModel
* update
* style
* update name
* trigger CIs
* Add way to use custom compile args
* style
* switch parameterization to generation_config
* Add to inits
* Update configuration_utils.py
* inits
* style
* docs
* style
* Update configuration_utils.py
* back without dataclass for repo consistency
* Update configuration_utils.py
* style
* style
* style once again
* add config serialization
* update
* true dataclass
* trigger CIs
* merge compile methods + remove serialization of compile config
What does this PR do?
As discussed, I don't think we should rely on defining an inner function and compiling it for every call to `generate`. This moves the compiled forward to `PreTrainedModel` instead for reuse, which makes the most sense to me. That way, every `PreTrainedModel` effectively has 2 forwards, and we can dynamically choose between the non-compiled one (prefill) and the compiled one (iterative decoding). This is similar to what PyTorch does internally when calling `model.compile()` in place, except that in their case they force the use of the compiled forward after the call was invoked.
Let me know what you think @ArthurZucker
Also addresses #34906
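To make the "two forwards" idea concrete, here is a small illustrative snippet. It assumes this PR's branch, where `compiled_call` is a property that lazily compiles `__call__` with the default `CompileConfig` (as discussed above); the model name is just an example:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
inputs = tokenizer("Hello", return_tensors="pt")

with torch.no_grad():
    eager_out = model(**inputs)                   # regular forward, used e.g. for prefill
    compiled_out = model.compiled_call(**inputs)  # compiled forward, reused during iterative decoding

print(eager_out.logits.shape, compiled_out.logits.shape)
```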