Add static cache #89
Conversation
Hi @eustlb, thanks for your great work. I'm trying to reproduce the result, but somehow the speed is very slow on an A100 80GB, maybe due to graph recompiles. Here are my full steps to reproduce:
import soundfile as sf
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
import time
import torch
import torch._dynamo.config
import torch._inductor.config
class Timer:
    """Context manager that times a CUDA region with CUDA events."""

    def __init__(self, name):
        self.name = name
        self.start_event = torch.cuda.Event(enable_timing=True)
        self.end_event = torch.cuda.Event(enable_timing=True)

    def __enter__(self):
        torch.cuda.synchronize()
        self.start_event.record()
        self.start = time.perf_counter()

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_event.record()
        torch.cuda.synchronize()
        elapsed_time = self.start_event.elapsed_time(self.end_event) * 1.0e-3
        print(f'{self.name} execution time:', elapsed_time, 'seconds')
...
with Timer("First run"):
generation = model.generate(
input_ids=input_ids,
prompt_input_ids=prompt_input_ids,
attention_mask=attention_mask,
prompt_attention_mask=prompt_attention_mask,
).to(torch.float32)
...
with Timer("Second run"):
generation = model.generate(
input_ids=input_ids,
prompt_input_ids=prompt_input_ids,
attention_mask=attention_mask,
prompt_attention_mask=prompt_attention_mask,
).to(torch.float32)
audio_arr = generation.cpu().numpy().squeeze()
sf.write("./output_2.wav", audio_arr, model.config.sampling_rate)
Logs:
Can you help me take a look and see if there is any problem in my env setup?
Hey @sang-nguyen-ts, thanks a lot for reporting this point. I looked into it, and you might notice that your provided code gives the expected run-time results when setting … The good news is that for benchmarking (code here), I happened to use three warmup steps without knowing about this point, which explains why I had not observed this behavior before.
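For illustration, a minimal sketch (not the repo's benchmark script) of how warmup runs keep compilation time out of the measurement; it reuses model and the inputs from the snippet above:

import time
import torch

# Run a few warmup generations first so that torch.compile's (re)compilation
# cost is excluded from the timed run.
for _ in range(3):
    _ = model.generate(
        input_ids=input_ids,
        prompt_input_ids=prompt_input_ids,
        attention_mask=attention_mask,
        prompt_attention_mask=prompt_attention_mask,
    )

torch.cuda.synchronize()
start = time.perf_counter()
_ = model.generate(
    input_ids=input_ids,
    prompt_input_ids=prompt_input_ids,
    attention_mask=attention_mask,
    prompt_attention_mask=prompt_attention_mask,
)
torch.cuda.synchronize()
print(f"Timed run: {time.perf_counter() - start:.3f} s")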
Wow, great job! I haven't looked into the code yet, but does it include the port to Transformers 4.42? PS. Ah, it requires even newer Transformers, probably due to EncoderDecoderCache usage.
I've checked the benchmark code and found that it uses the same prompt for every run. I tried a test case where the prompt differs across runs, and it causes a graph recompile:

prompt_1 = "A paragraph is defined as “a group of sentences or a single sentence that forms a unit” (Lunsford and Connors 116). Length and appearance do not determine whether a section in a paper is a paragraph."
prompt_2 = "Hey, how are you doing today?"
prompt_3 = "Hey, how are you doing mate?"
prompt_4 = "Hey, how are you doing my friend?"

Logs:
@sang-nguyen-ts compiling the forward pass does require every input to have a constant shape. That is why you'll find in the benchmarking code that we pad the tokenized description and prompt. The first prompt of your examples is 56 tokens long, which explains why the model is recompiled for the second prompt, which is padded to 50 tokens. Consider either setting …
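A hedged sketch of the fixed-length padding idea: padding every prompt to the same max_length keeps the compiled input shapes constant across prompts. The checkpoint name below is a placeholder, and the length of 50 simply echoes the padding mentioned above:

from transformers import AutoTokenizer

# Placeholder checkpoint name; any Parler-TTS tokenizer would do here.
tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler_tts_mini_v0.1")

prompts = [
    "Hey, how are you doing today?",
    "Hey, how are you doing mate?",
    "Hey, how are you doing my friend?",
]

# padding="max_length" always pads up to the same fixed length, so the
# compiled graph sees constant shapes and is not recompiled per prompt.
enc = tokenizer(
    prompts,
    padding="max_length",
    max_length=50,
    truncation=True,
    return_tensors="pt",
)
print(enc["input_ids"].shape)  # (3, 50), identical for every batch of prompts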
Very cool use of the static caches 💛 Possibly useful information for follow-up work on this repo -- on …
Hey @eustlb, congrats on the great PR! LGTM!
I've left a few questions, to make sure everything's under control and to satisfy my curiosity.
One last question on my side: how can I use the model to get the previous behaviour (i.e. without the cache)?
Also, note that I have yet to check whether it's still compatible with training; I will do that right now.
parler_tts/modeling_parler_tts.py
Outdated
@@ -244,7 +248,7 @@ def get_embedding(num_embeddings: int, embedding_dim: int):
     def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
         bsz, seq_len, _ = input_ids.size()
         # Create the position ids from the input token ids.
-        position_ids = (torch.arange(seq_len) + past_key_values_length).to(input_ids.device)
+        position_ids = (torch.arange(seq_len, device=input_ids.device) + past_key_values_length).to(input_ids.device)
Why do we cast the torch.arange to input_ids.device here? Is it because past_key_values_length is a tensor now? If so, can you update the type in the signature? And also, if so, do we have to cast to the input_ids device twice?
Nice catch.
The issue here is that there is an inconsistency in the type of past_key_values_length: it can be either 0, so an int, or a tensor (code taken from here). This inconsistency is handled well by torch when doing torch.arange(seq_len) + past_key_values_length, yet when past_key_values_length is a tensor that is not on the CPU device, we need to make sure that torch.arange creates its tensor on the same device. I did not use past_key_values_length.device to avoid the case where past_key_values_length is the int 0, and used this trick instead.
I agree that this is not very elegant, yet I don't see a more elegant way to do it. Maybe change the logic copied from Whisper's code that introduces this inconsistency?
torch.arange(seq_len, device=input_ids.device) + past_key_values_length should be enough then, right?
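A small standalone illustration (not from the PR) of the two cases discussed above, where past_key_values_length is either the plain int 0 or a tensor on the input's device:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
seq_len = 4

# Prefill case: past_key_values_length is the plain int 0.
position_ids = torch.arange(seq_len, device=device) + 0

# Decoding case: past_key_values_length is a tensor living on the input device.
past_key_values_length = torch.tensor(3, device=device)
position_ids = torch.arange(seq_len, device=device) + past_key_values_length

# Creating the arange directly on the target device keeps both results on
# `device`, so the trailing .to(input_ids.device) is effectively a no-op.
print(position_ids)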
        attention_mask: Optional[torch.Tensor] = None,
        cos: Optional[torch.LongTensor] = None,
        sin: Optional[torch.LongTensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""
Great refactoring!
Besides the cache-related changes, what motivated the changes in shape? Did you test speed-ups? Or did you just modify it to stick with Whisper's implementation?
Also, you previously told me about small differences between the (static KV cache + compile) version and the previous implementation; are you sure they're not coming from the changes in shape you've introduced here? E.g. you do attn_output = torch.matmul(attn_probs, value_states) with 4D tensors, whereas we previously did attn_output = torch.bmm(attn_probs, value_states) with 3D tensors, which might result in small differences.
Modified to stick with Whisper's implementation, which is allegedly faster (see Whisper's static cache PR), yet I have not benchmarked it myself.
Concerning the change from torch.bmm to torch.matmul, the tests I've run showed exactly the same results for every dtype.
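A quick standalone check (not part of this PR's tests) comparing the 4D torch.matmul path against the equivalent 3D torch.bmm path:

import torch

bsz, num_heads, tgt_len, src_len, head_dim = 2, 4, 5, 7, 16
attn_probs = torch.softmax(torch.randn(bsz, num_heads, tgt_len, src_len), dim=-1)
value_states = torch.randn(bsz, num_heads, src_len, head_dim)

# 4D path, as in the refactored attention.
out_matmul = torch.matmul(attn_probs, value_states)

# Equivalent 3D path, as in the previous implementation.
out_bmm = torch.bmm(
    attn_probs.reshape(bsz * num_heads, tgt_len, src_len),
    value_states.reshape(bsz * num_heads, src_len, head_dim),
).reshape(bsz, num_heads, tgt_len, head_dim)

# In practice the two agree (exactly so in the tests mentioned above).
print(torch.allclose(out_matmul, out_bmm))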
@@ -1948,10 +2030,11 @@ def generate(
            inputs, generation_config.bos_token_id, model_kwargs
        )
        batch_size = input_ids.shape[0] // self.num_codebooks
        kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
Even though we don't ever use ParlerTTSForCausalLM, have you tested it still? No worries if not.
Not tested at all; I have not propagated the changes to ParlerTTSForCausalLM.
        prompt_hidden_states = model_kwargs["prompt_hidden_states"]
        num_codebooks = self.decoder.num_codebooks
        input = decoder_input_ids.reshape(-1, num_codebooks, decoder_input_ids.shape[-1])
        inputs_embeds = sum(
            [
                self.decoder.model.decoder.embed_tokens[codebook](input[:, codebook])
                for codebook in range(num_codebooks)
            ]
        )
        inputs_embeds = torch.cat([prompt_hidden_states, inputs_embeds], dim=1)
        model_kwargs["inputs_embeds"] = inputs_embeds
Not sure I follow! You prepend here, which makes total sense, but this still leaves prompt_hidden_states in model_kwargs, so isn't it supposed to be used again in the forward pass?
We prepend here to make sure that we have "inputs_embeds" in the model_kwargs when _get_initial_cache_position is called. I have not removed prompt_hidden_states from the model_kwargs here because I wanted to change as little as possible in the current logic. Indeed, in the current implementation of prepare_inputs_for_generation, decoder_inputs_embeds (which corresponds to the inputs_embeds created here by prepending) is not handled, and I have not changed that. Were we to remove prompt_hidden_states from the model_kwargs at this stage, we would then need to handle decoder_inputs_embeds in prepare_inputs_for_generation, which is doable, yet I felt the impact was negligible and preferred to aim for a version that changes the current logic as little as possible. See the sketch below for what the prepending does to the shapes.
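A simplified standalone sketch (not the repo's actual code) of the prepending described above; the shapes and sizes are made up for illustration:

import torch
import torch.nn as nn

bsz, num_codebooks, seq_len, prompt_len, hidden = 1, 4, 3, 5, 8
embed_tokens = nn.ModuleList(nn.Embedding(100, hidden) for _ in range(num_codebooks))

decoder_input_ids = torch.zeros(bsz * num_codebooks, seq_len, dtype=torch.long)
prompt_hidden_states = torch.randn(bsz, prompt_len, hidden)

# Sum the per-codebook embeddings, then prepend the prompt hidden states along
# the time dimension so they are part of inputs_embeds from the very start.
ids = decoder_input_ids.reshape(-1, num_codebooks, seq_len)
inputs_embeds = sum(embed_tokens[k](ids[:, k]) for k in range(num_codebooks))
inputs_embeds = torch.cat([prompt_hidden_states, inputs_embeds], dim=1)
print(inputs_embeds.shape)  # torch.Size([1, 8, 8]) == (bsz, prompt_len + seq_len, hidden)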
Another question on my side!
Co-authored-by: Yoach Lacombe <[email protected]>
I had the same intuition, but it seems that a call to the full generate is necessary to warm up compilation, likely due to torch.compile's inner workings. Notice that the same warmup is done for other Transformers models.
I've left some final remarks regarding training compatibility; thanks for the great work!
For the final changes, it'd be great if you could check that they don't break compile compatibility and run your standard tests to make sure everything's okay!
        if isinstance(past_key_value, StaticCache):
            raise ValueError(
                "The `static` cache implementation is not compatible with `attn_implementation='flash_attention_2'`. "
                "Use `attn_implementation='sdpa'` in the meantime, and open an issue at https://github.com/huggingface/transformers"
            )
Let's focus on the current behaviour; if we need to add it, we'll do it in a subsequent PR.
Thanks @eustlb, next steps:
Merging now ;)
This PR enables compilation of the forward pass of Parler-TTS.
Dynamic caching of keys and values during auto-regressive decoding makes the tensors in past_key_values change shape and stride, causing a recompilation at each pass of the forward method. In this PR, we implement a static key/value cache, enabling a single compilation of the model when generating. The work done here is inspired by this PR already done on Whisper.
Env notes
Build your environment as such:
Usage
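The original usage snippet is not reproduced here; below is only a hedged sketch of how compiled generation with a static cache is typically wired up. The checkpoint name, dtype, compile mode, the cache_implementation setting and the padding length are assumptions, not copied verbatim from this PR:

import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer

torch_device = "cuda:0"
# Placeholder checkpoint; dtype and compile mode are assumptions.
model = ParlerTTSForConditionalGeneration.from_pretrained(
    "parler-tts/parler_tts_mini_v0.1", torch_dtype=torch.float16
).to(torch_device)
tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler_tts_mini_v0.1")

# Static KV cache + a single compilation of the forward pass.
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead")

description = "A female speaker delivers her words quite expressively."
prompt = "Hey, how are you doing today?"

# Pad to fixed lengths so the compiled shapes stay constant (see the discussion above).
inputs = tokenizer(description, padding="max_length", max_length=50, return_tensors="pt").to(torch_device)
prompt_inputs = tokenizer(prompt, padding="max_length", max_length=50, return_tensors="pt").to(torch_device)

generation = model.generate(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    prompt_input_ids=prompt_inputs.input_ids,
    prompt_attention_mask=prompt_inputs.attention_mask,
).to(torch.float32)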
Benchmarks
Benchmarking code can be found here.
Reported results compare the best configuration (attention implementation, dtype) with compile vs. the best configuration without compile, for generating 43 tokens (~0.5 s of audio):
Tests
This PR has been tested for generation by comparing generation outputs between this branch and the one it was built on. The code for these tests can be found here.