
Add static cache #89

Merged · 69 commits · Aug 7, 2024

Commits
20dfbff
add rope
sanchit-gandhi May 17, 2024
fd16411
don't include padding in rope
sanchit-gandhi May 17, 2024
1344de2
possibly use cross-attn for prompt
sanchit-gandhi May 17, 2024
e3fe843
fix rope
sanchit-gandhi May 18, 2024
4f27ff7
fix cross-attn
sanchit-gandhi May 18, 2024
c8887f2
fix self-attn
sanchit-gandhi May 18, 2024
ba0bfca
fix dummy model
sanchit-gandhi May 20, 2024
387de15
clean-up rope
sanchit-gandhi May 21, 2024
e0500d0
first gqa implementation
ylacombe May 23, 2024
3342e2e
fix wer eval
ylacombe May 23, 2024
f094f7b
Merge branch 'huggingface:main' into gqa-experiments
ylacombe May 23, 2024
b171d98
feat: add flash attention and sdpa
sang-nguyen-ts May 29, 2024
9e915aa
chore: add README for flash attention
sang-nguyen-ts May 29, 2024
1eeca66
chore: add benchmark script
sang-nguyen-ts May 29, 2024
69521dd
chore: add benchmark attention approach
sang-nguyen-ts May 29, 2024
7539000
multi node and fix wer and fix compile
May 30, 2024
76dbe87
Update modeling_parler_tts.py
ylacombe May 30, 2024
9d67863
Merge pull request #2 from sang-nguyen-ts/pef/flash-sdpa-attention
ylacombe May 30, 2024
cd4fcc1
Merge branch 'architecture-experiments' into cross-attn
ylacombe May 30, 2024
934f08c
Merge pull request #1 from sanchit-gandhi/cross-attn
ylacombe May 30, 2024
fc79c06
fix FA2, SDPA and add cross-attn MHA and attention type forcing
ylacombe May 31, 2024
7808285
better cross_attention key values number of heads default + add train…
ylacombe Jun 4, 2024
0ce0df2
fix audio padding when torch compile or pad_to_max_length=True
ylacombe Jun 4, 2024
8198fd9
correct multi node
Jun 5, 2024
9b48d0a
make rope faster
ylacombe Jun 5, 2024
54b56d9
fix encoder sdpa
ylacombe Jun 5, 2024
1da55fc
fix training with cross attention + with FA2
ylacombe Jun 5, 2024
8f6047a
use fp32 as default model dtype + fix generation when using FA2 with …
ylacombe Jun 6, 2024
d056ca5
remove redundant passes in generate + clean and fix attentions
ylacombe Jun 6, 2024
7dfbbca
fix edge case in WER evaluation when longform generation
Jun 6, 2024
15edf7c
better multi-node mapping and saving / add eval dataloader num workers
Jun 7, 2024
ef40654
remove old benchmarks
Jun 7, 2024
954d8c5
faster audio encoding + checkpointing + fix generation step
Jun 21, 2024
25490f0
Merge branch 'dev' into architecture-experiments
ylacombe Jul 8, 2024
011855e
Merge pull request #82 from ylacombe/architecture-experiments
ylacombe Jul 8, 2024
f9c36ac
unpin trfms
eustlb Jul 8, 2024
e52a8f0
remove CFG
eustlb Jul 8, 2024
ccae5a9
imports and constants
eustlb Jul 22, 2024
a9f75d5
attention modifications to handle static cache
eustlb Jul 22, 2024
89e50d5
decoder layer modification to handle static cache
eustlb Jul 22, 2024
fb750fe
ParlerTTSPreTrainedModel modifs to handle static cache
eustlb Jul 22, 2024
41fa4fd
ParlerTTSDecoder modifs to handle static cache
eustlb Jul 22, 2024
5a484f8
ParlerTTSModel + ParlerTTSForCausalLM modifs
eustlb Jul 22, 2024
c5da07e
ParlerTTSForConditionalGeneration modifs
eustlb Jul 22, 2024
8c780ef
decoder_attention_mask for static cache
eustlb Jul 22, 2024
afa18a3
create inputs_embeds early to have a good cache initialization
eustlb Jul 22, 2024
45d0fbb
_get_cache method
eustlb Jul 22, 2024
054b751
init the cache
eustlb Jul 22, 2024
11b693f
ensure good device
eustlb Jul 22, 2024
6af19d6
pin tfrms version
eustlb Jul 22, 2024
682ca70
fix attention_mask FA2
eustlb Jul 24, 2024
024a354
remove unnecessary method
eustlb Jul 24, 2024
a097aa4
Update parler_tts/modeling_parler_tts.py
eustlb Jul 24, 2024
ecd06c1
Update parler_tts/modeling_parler_tts.py
eustlb Jul 24, 2024
ad25e2b
remove unnecessary imports
eustlb Jul 24, 2024
f29392b
replace the hardcoded cache_position with a more elegant approach
eustlb Jul 25, 2024
d08e4eb
make style
eustlb Jul 31, 2024
15d8e6a
Merge remote-tracking branch 'upstream/main' into add-static-cache
eustlb Jul 31, 2024
0e46d0b
unpin transformers
eustlb Aug 1, 2024
43764cd
pin transformers
eustlb Aug 1, 2024
b5e25f0
pin torch
eustlb Aug 1, 2024
441bafd
refactor + unpin torch
eustlb Aug 1, 2024
45ee62e
Update parler_tts/modeling_parler_tts.py
eustlb Aug 1, 2024
a294fb5
update training script to match 11b209e
eustlb Aug 1, 2024
824b183
Update parler_tts/modeling_parler_tts.py
eustlb Aug 1, 2024
66fdbde
ensure compatibility with trfms 4.43.3, changes taken from #31980 on …
eustlb Aug 1, 2024
dc21c87
fix input_ids_length
eustlb Aug 2, 2024
650f276
warning full attention mask creation
eustlb Aug 5, 2024
41edc2a
changes for training compatibility
eustlb Aug 7, 2024
5 changes: 3 additions & 2 deletions helpers/gradio_demo/app.py
@@ -1,8 +1,9 @@
import gradio as gr
import torch
from transformers import AutoFeatureExtractor, AutoTokenizer, set_seed

from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed


device = "cuda:0" if torch.cuda.is_available() else "cpu"

@@ -57,7 +58,7 @@ def gen_tts(text, description):
background-color: #000000;
justify-content: center;
align-items: center;
border-radius: 9999px !important;
border-radius: 9999px !important;
width: 13rem;
margin-top: 10px;
margin-left: auto;
8 changes: 5 additions & 3 deletions helpers/model_init_scripts/init_dummy_model.py
@@ -1,7 +1,9 @@
from parler_tts import ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration, ParlerTTSDecoderConfig
from transformers import AutoConfig
import os
import argparse
import os

from transformers import AutoConfig

from parler_tts import ParlerTTSDecoderConfig, ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration


if __name__ == "__main__":
9 changes: 6 additions & 3 deletions helpers/model_init_scripts/init_dummy_model_with_encodec.py
@@ -1,7 +1,10 @@
from parler_tts import ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration, ParlerTTSDecoderConfig
from transformers import AutoConfig
import os
import argparse
import os

from transformers import AutoConfig

from parler_tts import ParlerTTSDecoderConfig, ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration


if __name__ == "__main__":
parser = argparse.ArgumentParser()
8 changes: 5 additions & 3 deletions helpers/model_init_scripts/init_model_600M.py
@@ -1,7 +1,9 @@
from parler_tts import ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration, ParlerTTSDecoderConfig
from transformers import AutoConfig
import os
import argparse
import os

from transformers import AutoConfig

from parler_tts import ParlerTTSDecoderConfig, ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration


if __name__ == "__main__":
2 changes: 2 additions & 0 deletions helpers/push_to_hub_scripts/push_dac_to_hub.py
@@ -1,4 +1,6 @@
import dac
from transformers import AutoConfig, AutoModel, EncodecFeatureExtractor

from parler_tts import DACConfig, DACModel
from transformers import AutoConfig, AutoModel
from transformers import EncodecFeatureExtractor
@@ -1,5 +1,7 @@
from transformers import AutoFeatureExtractor, AutoTokenizer

from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer, AutoFeatureExtractor


path = "TODO"
repo_id = "parler_tts_600M"
5 changes: 3 additions & 2 deletions parler_tts/__init__.py
@@ -1,16 +1,17 @@
__version__ = "0.1"


from transformers import AutoConfig, AutoModel

from .configuration_parler_tts import ParlerTTSConfig, ParlerTTSDecoderConfig
from .dac_wrapper import DACConfig, DACModel
from .modeling_parler_tts import (
ParlerTTSForCausalLM,
ParlerTTSForConditionalGeneration,
apply_delay_pattern_mask,
build_delay_pattern_mask,
)

from .dac_wrapper import DACConfig, DACModel
from transformers import AutoConfig, AutoModel

AutoConfig.register("dac", DACConfig)
AutoModel.register(DACConfig, DACModel)
2 changes: 1 addition & 1 deletion parler_tts/dac_wrapper/configuration_dac.py
@@ -1,5 +1,5 @@

from transformers import PretrainedConfig
from typing import List


class DACConfig(PretrainedConfig):
11 changes: 5 additions & 6 deletions parler_tts/dac_wrapper/modeling_dac.py
@@ -1,10 +1,9 @@
import torch

from dac.model import DAC
from transformers import PreTrainedModel
from transformers.models.encodec.modeling_encodec import EncodecEncoderOutput, EncodecDecoderOutput
from .configuration_dac import DACConfig
from transformers.models.encodec.modeling_encodec import EncodecDecoderOutput, EncodecEncoderOutput

from dac.model import DAC
from .configuration_dac import DACConfig


# model doesn't support batching yet
@@ -79,7 +78,7 @@ def encode(
)

for offset in range(0, input_length - step, stride):
mask = padding_mask[..., offset : offset + chunk_length].bool()
padding_mask[..., offset : offset + chunk_length].bool()
frame = audio_data[:, :, offset : offset + chunk_length]

scale = None
@@ -134,4 +133,4 @@ def decode(
return EncodecDecoderOutput(audio_values)

def forward(self, tensor):
raise ValueError(f"`DACModel.forward` not implemented yet")
raise ValueError("`DACModel.forward` not implemented yet")
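The `encode` hunk above iterates over strided windows via `range(0, input_length - step, stride)`. A minimal pure-Python stand-in for that windowing (the helper names are hypothetical, and `step` is computed elsewhere in the real model, so it is passed explicitly here):

```python
# Illustrative sketch of strided chunked encoding; not the DACModel code.

def chunk_offsets(input_length, chunk_length, stride, step):
    # Mirrors `range(0, input_length - step, stride)` from the loop above.
    return list(range(0, input_length - step, stride))

def encode_chunks(audio, chunk_length, stride, step):
    # Slice fixed-length frames from the signal; the real implementation
    # additionally applies a padding mask per window (omitted here).
    return [
        audio[offset : offset + chunk_length]
        for offset in chunk_offsets(len(audio), chunk_length, stride, step)
    ]

frames = encode_chunks(list(range(10)), chunk_length=4, stride=2, step=2)
# frames -> [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7], [6, 7, 8, 9]]
```

With `stride` smaller than `chunk_length`, consecutive windows overlap, which is why the removed `mask = ...` assignment in the hunk above was dead code the review flagged: the mask was computed per window but never used.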