
Automatic compilation in generate: do not rely on inner function #34923

Merged — 25 commits into main from compile-sample, Dec 3, 2024

Conversation

@Cyrilvallez (Member) commented Nov 25, 2024

What does this PR do?

As discussed, I don't think we should rely on defining an inner function and compiling it on every call to generate.

This moves the compiled forward to PreTrainedModel instead so it can be reused, which makes the most sense to me. That way, every PreTrainedModel effectively has two forwards, and we can dynamically choose between the non-compiled one (prefill) and the compiled one (iterative decoding). This is similar to what PyTorch does internally when model.compile() is called in place, except that in their case the compiled forward is forced for every call after the compile call was invoked.

Let me know what you think @ArthurZucker

Also addresses #34906
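
A minimal sketch of the two-forward idea described above (illustrative names and simplifications, not the exact implementation in the PR):

import torch

class SketchPreTrainedModel(torch.nn.Module):
    # Sketch only: the real PreTrainedModel keeps its usual forward; this just
    # shows the second, lazily compiled entry point that generate() can pick.
    @property
    def compiled_call(self):
        if getattr(self, "_compiled_call", None) is None:
            self._compiled_call = torch.compile(self.__call__)
        return self._compiled_call

# Inside generate(), roughly:
#   prefill step:        outputs = self(**model_inputs)                # eager forward
#   iterative decoding:  outputs = self.compiled_call(**model_inputs)  # compiled forward, reused across calls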


@ArthurZucker (Collaborator) left a comment

Thanks! Before merging, can you make sure:

  • this fixes the PEFT issue (we are missing a fast test, let's add one!)
  • this still has the expected performance gains!

🤗 LGTM otherwise!

src/transformers/modeling_utils.py (resolved review thread)
@Cyrilvallez (Member, Author) commented:

Hey @BenjaminBossan, could you check whether this PR solves the issue in PEFT, or point me to the failing tests, please?

@SilverSoldier (Contributor) commented:

I raised this point in issue #34906 as well: how can we pass custom arguments to compile in this case, say a different backend or other parameters?
There should be a way to pass user arguments for non-default cases.

@ArthurZucker (Collaborator) commented:

Yeah, we can use generate_kwargs for this. The original implementation was kept super minimal for big gains, but I'm down to include that!

@ArthurZucker (Collaborator) left a comment

Let's rather go with generation_config / generate kwargs that can be passed and used!
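
For illustration, this is roughly what passing non-default compile options through generate looks like with the compile_config kwarg used later in this thread (the model, prompt, and chosen options here are just examples):

from transformers import AutoModelForCausalLM, AutoTokenizer, CompileConfig
import torch

model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B", torch_dtype=torch.float16).to("cuda")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
inputs = tokenizer.encode("Hey what's the plan", return_tensors="pt").to("cuda")

# Non-default compile options travel with the generate call instead of being
# hard-coded in an inner function.
out = model.generate(
    inputs,
    do_sample=False,
    max_new_tokens=100,
    cache_implementation="static",
    compile_config=CompileConfig(dynamic=True),
)
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])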

@BenjaminBossan (Member) left a comment

> Hey @BenjaminBossan, could you check whether this PR solves the issue in PEFT, or point me to the failing tests, please?

I checked and this branch indeed resolves the failing tests, thanks!

src/transformers/modeling_utils.py (resolved review threads)
# Only reset it if different from previous config
if getattr(self, "_last_compile_config", CompileConfig()) != compile_config:
    self._last_compile_config = compile_config
    self._compiled_call = torch.compile(self.__call__, **compile_config.to_dict())
@ydshieh (Collaborator) commented Nov 27, 2024

I am wondering whether it would (probably) make more sense to do

self._compiled_call = None

at this line, and delegate the actual compilation to compiled_call only (as you already do there), when it is called (like you do in generate).

Having two places that call torch.compile seems slightly strange to me (but it's okay).

And if you decide to take my suggestion, the name _set_compile_call should probably be changed to _set_compile_config.
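
A rough sketch of that suggestion (illustrative only, reusing the _last_compile_config / _compiled_call attributes from the diff above): the setter only records the config, and the property becomes the single call site of torch.compile.

import torch
from transformers import CompileConfig

class SketchModel(torch.nn.Module):
    def _set_compile_config(self, compile_config: CompileConfig):
        # Only record the config; drop the cached compiled function if the config changed.
        if getattr(self, "_last_compile_config", CompileConfig()) != compile_config:
            self._last_compile_config = compile_config
            self._compiled_call = None

    @property
    def compiled_call(self):
        # Lazily compile with the stored config; this is now the only place calling torch.compile.
        if getattr(self, "_compiled_call", None) is None:
            config = getattr(self, "_last_compile_config", CompileConfig())
            self._compiled_call = torch.compile(self.__call__, **config.to_dict())
        return self._compiled_call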

Collaborator replied:

torch.compile does not compile; calling the compiled function does.

@ydshieh (Collaborator) replied Nov 27, 2024

I know. What I mean here is not the underlying compilation happening, but rather that torch.compile is called twice. Two different methods calling torch.compile doesn't seem like very good style to me (in my view, it's better if one method only sets the config and another one does the job of calling torch.compile).

It's no big deal though, just a habit of making each method do its own job.

@Cyrilvallez (Member, Author) replied Nov 27, 2024

You're right, but maybe we should then use _set_compile_call inside compiled_call instead? That way, _set_compile_call is the only place using torch.compile, and compiled_call is still only used for accessing the function (with the drawback of a call to _set_compile_call in case the compiled function does not already exist).

Collaborator replied:

Even with that, in generate we would still need to call _set_compile_call and then compiled_call, right? (since compiled_call doesn't take the argument). If that's the case, then it's odd to have compiled_call calling _set_compile_call.

Otherwise, if you think it's okay/good to change compiled_call to accept a compile_config argument, then your suggested change sounds good to me.

@ydshieh (Collaborator) replied Nov 27, 2024

Oh, compiled_call is a property, so it can't take an argument... so the concern in the first paragraph of the comment above still applies.

Collaborator replied:

Anyway, these are implementation details that don't affect users. Don't take it too seriously if the changes would take too much time.

@Cyrilvallez (Member, Author) replied Nov 27, 2024

Yes, I thought it made sense to make it a property so that model.compiled_call(**inputs) could always be used directly as an alternative to model.forward(**inputs).
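
In usage terms, the intent is roughly this (a sketch of the idea, not a guaranteed public API):

# Both entry points stay available on the same model:
eager_out = model(**inputs)                   # regular forward (e.g. prefill)
compiled_out = model.compiled_call(**inputs)  # compiled forward, reused across generate calls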

@ydshieh (Collaborator) commented Nov 27, 2024

A couple more comments:

  • (nit) I am not sure it's necessary to save the compile config (unlike our other config classes).
  • I understand we might want to change the compile options. But now imagine a use case: I compile with option 1, then with option 2, and then I want to check option 1 again for some reason. Question: would compiling with option 1 the second time take the same amount of time as compiling with option 1 the first time?

@Cyrilvallez (Member, Author) commented:

Thanks for looking into this as well @ydshieh! Regarding your question, torch is able to cache different graphs for the exact same function, so no, it will actually not re-compile even after the switch. I'm not sure how many times you can do it before starting to lose cache entries, though. But the following:

from transformers import AutoModelForCausalLM, AutoTokenizer, CompileConfig
import torch
import time
import warnings

warnings.filterwarnings("ignore")
device = 3

model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B", torch_dtype=torch.float16).to(device)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")

sequence = "Hey what's the plan"

inputs = tokenizer.encode(sequence, return_tensors='pt').to(device)
model.generation_config.temperature = 1.0
model.generation_config.top_p = 1.0

# Compile default config
t0 = time.time()
out = model.generate(inputs, do_sample=False, max_new_tokens=500, cache_implementation="static")
out = tokenizer.batch_decode(out, skip_special_tokens=True)[0]
dt = time.time() - t0
print(f'dt: {dt}')
print(model._last_compile_config.to_dict())


# Compile new config
t0 = time.time()
out = model.generate(inputs, do_sample=False, max_new_tokens=500, cache_implementation="static", compile_config=CompileConfig(dynamic=True))
out = tokenizer.batch_decode(out, skip_special_tokens=True)[0]
dt = time.time() - t0
print(f'dt: {dt}')
print(model._last_compile_config.to_dict())

# Back to 1st config
t0 = time.time()
out = model.generate(inputs, do_sample=False, max_new_tokens=500, cache_implementation="static")
out = tokenizer.batch_decode(out, skip_special_tokens=True)[0]
dt = time.time() - t0
print(f'dt: {dt}')
print(model._last_compile_config.to_dict())


# Back to 2nd config
t0 = time.time()
out = model.generate(inputs, do_sample=False, max_new_tokens=500, cache_implementation="static", compile_config=CompileConfig(dynamic=True))
out = tokenizer.batch_decode(out, skip_special_tokens=True)[0]
dt = time.time() - t0
print(f'dt: {dt}')
print(model._last_compile_config.to_dict())

actually does not re-compile for the last two generate calls, and reuses the already compiled graphs.

@ArthurZucker (Collaborator) left a comment

LGTM, could you confirm the performance boost with a gist script shared here (like the one I shared), just to double-check? 🤗

@Cyrilvallez (Member, Author) commented:

Confirming that this script returns

>>> Without static cache and compile: 27.079 s
>>> Using `torch.compile`.
>>> Compiling default config: 26.816 s
>>> Using compiled graph: 6.541 s
>>> Compiling new config: 24.327 s
>>> Using compiled new graph: 6.930 s
>>> Back to 1st config and graph: 6.528 s

on a machine with an A100 GPU, which is the expected result.

@Cyrilvallez merged commit ee37bf0 into main on Dec 3, 2024
25 checks passed
@Cyrilvallez deleted the compile-sample branch on December 3, 2024 at 10:20
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request on Dec 5, 2024:
Automatic compilation in generate: do not rely on inner function (huggingface#34923)

* compiled forward in PreTrainedModel
* update
* style
* update name
* trigger CIs
* Add way to use custom compile args
* style
* switch parameterization to generation_config
* Add to inits
* Update configuration_utils.py
* inits
* style
* docs
* style
* Update configuration_utils.py
* back without dataclass for repo consistency
* Update configuration_utils.py
* style
* style
* style once again
* add config serialization
* update
* true dataclass
* trigger CIs
* merge compile methods + remove serialization of compile config