Add static cache #89

Merged
69 commits merged on Aug 7, 2024
Changes from 50 commits

Commits (69)
20dfbff
add rope
sanchit-gandhi May 17, 2024
fd16411
don't include padding in rope
sanchit-gandhi May 17, 2024
1344de2
possibly use cross-attn for prompt
sanchit-gandhi May 17, 2024
e3fe843
fix rope
sanchit-gandhi May 18, 2024
4f27ff7
fix cross-attn
sanchit-gandhi May 18, 2024
c8887f2
fix self-attn
sanchit-gandhi May 18, 2024
ba0bfca
fix dummy model
sanchit-gandhi May 20, 2024
387de15
clean-up rope
sanchit-gandhi May 21, 2024
e0500d0
first gqa implementation
ylacombe May 23, 2024
3342e2e
fix wer eval
ylacombe May 23, 2024
f094f7b
Merge branch 'huggingface:main' into gqa-experiments
ylacombe May 23, 2024
b171d98
feat: add flash attention and sdpa
sang-nguyen-ts May 29, 2024
9e915aa
chore: add README for flash attention
sang-nguyen-ts May 29, 2024
1eeca66
chore: add benchmark script
sang-nguyen-ts May 29, 2024
69521dd
chore: add benchmark attention approach
sang-nguyen-ts May 29, 2024
7539000
multi node and fix wer and fix compile
May 30, 2024
76dbe87
Update modeling_parler_tts.py
ylacombe May 30, 2024
9d67863
Merge pull request #2 from sang-nguyen-ts/pef/flash-sdpa-attention
ylacombe May 30, 2024
cd4fcc1
Merge branch 'architecture-experiments' into cross-attn
ylacombe May 30, 2024
934f08c
Merge pull request #1 from sanchit-gandhi/cross-attn
ylacombe May 30, 2024
fc79c06
fix FA2, SDPA and add cross-attn MHA and attention type forcing
ylacombe May 31, 2024
7808285
better cross_attention key values number of heads default + add train…
ylacombe Jun 4, 2024
0ce0df2
fix audio padding when torch compile or pad_to_max_length=True
ylacombe Jun 4, 2024
8198fd9
correct multi node
Jun 5, 2024
9b48d0a
make rope faster
ylacombe Jun 5, 2024
54b56d9
fix encoder sdpa
ylacombe Jun 5, 2024
1da55fc
fix training with cross attention + with FA2
ylacombe Jun 5, 2024
8f6047a
use fp32 as default model dtype + fix generation when using FA2 with …
ylacombe Jun 6, 2024
d056ca5
remove redundant passes in generate + clean and fix attentions
ylacombe Jun 6, 2024
7dfbbca
fix edge case in WER evaluation when longform generation
Jun 6, 2024
15edf7c
better multi-node mapping and saving / add eval dataloader num workers
Jun 7, 2024
ef40654
remove old benchmarks
Jun 7, 2024
954d8c5
faster audio encoding + checkpointing + fix generation step
Jun 21, 2024
25490f0
Merge branch 'dev' into architecture-experiments
ylacombe Jul 8, 2024
011855e
Merge pull request #82 from ylacombe/architecture-experiments
ylacombe Jul 8, 2024
f9c36ac
unpin trfms
eustlb Jul 8, 2024
e52a8f0
remove CFG
eustlb Jul 8, 2024
ccae5a9
imports and constants
eustlb Jul 22, 2024
a9f75d5
attention modifications to handle static cache
eustlb Jul 22, 2024
89e50d5
decoder layer modification to handle static cache
eustlb Jul 22, 2024
fb750fe
ParlerTTSPreTrainedModel modifs to handle static cache
eustlb Jul 22, 2024
41fa4fd
ParlerTTSDecoder modifs to handle static cache
eustlb Jul 22, 2024
5a484f8
ParlerTTSModel + ParlerTTSForCausalLM modifs
eustlb Jul 22, 2024
c5da07e
ParlerTTSForConditionalGeneration modifs
eustlb Jul 22, 2024
8c780ef
decoder_attention_mask for static cache
eustlb Jul 22, 2024
afa18a3
create inputs_embeds early to have a good cache initialization
eustlb Jul 22, 2024
45d0fbb
_get_cache method
eustlb Jul 22, 2024
054b751
init the cache
eustlb Jul 22, 2024
11b693f
ensure good device
eustlb Jul 22, 2024
6af19d6
pin tfrms version
eustlb Jul 22, 2024
682ca70
fix attention_mask FA2
eustlb Jul 24, 2024
024a354
remove unnecessary method
eustlb Jul 24, 2024
a097aa4
Update parler_tts/modeling_parler_tts.py
eustlb Jul 24, 2024
ecd06c1
Update parler_tts/modeling_parler_tts.py
eustlb Jul 24, 2024
ad25e2b
remove unnecessary imports
eustlb Jul 24, 2024
f29392b
replace the hardcoded cache_position with a more elegant approach
eustlb Jul 25, 2024
d08e4eb
make style
eustlb Jul 31, 2024
15d8e6a
Merge remote-tracking branch 'upstream/main' into add-static-cache
eustlb Jul 31, 2024
0e46d0b
unpin transformers
eustlb Aug 1, 2024
43764cd
pin transformers
eustlb Aug 1, 2024
b5e25f0
pin torch
eustlb Aug 1, 2024
441bafd
refactor + unpin torch
eustlb Aug 1, 2024
45ee62e
Update parler_tts/modeling_parler_tts.py
eustlb Aug 1, 2024
a294fb5
update training script to match 11b209e
eustlb Aug 1, 2024
824b183
Update parler_tts/modeling_parler_tts.py
eustlb Aug 1, 2024
66fdbde
ensure compatibility with trfms 4.43.3, changes taken from #31980 on …
eustlb Aug 1, 2024
dc21c87
fix input_ids_length
eustlb Aug 2, 2024
650f276
warning full attention mask creation
eustlb Aug 5, 2024
41edc2a
changes for training compatibility
eustlb Aug 7, 2024
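
The commit history above builds up the feature set reviewed here: rotary position embeddings, grouped-query attention, FlashAttention-2/SDPA support, and finally the static key-value cache that gives the PR its name. As a minimal sketch of how a static cache is typically enabled on a transformers generation model (the generic cache_implementation="static" plus torch.compile pattern; whether Parler-TTS exposes exactly this interface after this PR is an assumption here):

import torch
from parler_tts import ParlerTTSForConditionalGeneration

model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler_tts_mini_v0.1")

# Pre-allocate fixed-size key/value tensors instead of growing them at every decoding step.
model.generation_config.cache_implementation = "static"

# With static shapes, the decoder forward can be compiled once and reused for each step.
model.forward = torch.compile(model.forward, mode="reduce-overhead")

The first calls after compilation are slow (graph capture and warm-up); the speed-up shows up on subsequent generations. The diff below is taken at commit 50 of the 69.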
README.md (10 additions, 0 deletions)

@@ -54,6 +54,16 @@ if torch.xpu.is_available():
torch_dtype = torch.float16 if device != "cpu" else torch.float32

model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler_tts_mini_v0.1").to(device, dtype=torch_dtype)

# # Use with flash attention
# model = ParlerTTSForConditionalGeneration.from_pretrained(
# repo_id, attn_implementation="flash_attention_2", torch_dtype=torch.float16
# ).to(device, dtype=torch_dtype)


model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler_tts_mini_v0.1").to(device, dtype=torch_dtype)


tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler_tts_mini_v0.1")

prompt = "Hey, how are you doing today?"
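For context, the surrounding README snippet (collapsed in this hunk) continues roughly as below; a sketch assuming the standard Parler-TTS generate signature, where input_ids carries the tokenized voice description and prompt_input_ids the transcript to speak, reusing the device, tokenizer, and prompt set up in the snippet above:

import soundfile as sf
import torch

# Illustrative description; any free-text voice description works here.
description = "A female speaker delivers her words expressively, with clear audio quality."

input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

# generate() returns the decoded waveform directly.
generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
audio_arr = generation.to(torch.float32).cpu().numpy().squeeze()  # cast in case the model runs in float16
sf.write("parler_tts_out.wav", audio_arr, model.config.sampling_rate)

Loading with attn_implementation="flash_attention_2" only changes how attention is computed; the generate call itself is unchanged.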
helpers/model_init_scripts/init_dummy_model_with_encodec.py (3 additions, 0 deletions)

@@ -58,4 +58,7 @@
model.generation_config.do_sample = True # True
model.generation_config.guidance_scale = 1 # 3.0

model.config.pad_token_id = encodec_vocab_size
model.config.decoder_start_token_id = encodec_vocab_size + 1

model.save_pretrained(os.path.join(args.save_directory, "tiny-model"))
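
The two ids added here sit immediately after the audio codebook range, which matches the decoder-config defaults visible further down (eos_token_id=2048, bos_token_id=2049). Illustratively, with an assumed codebook size of 2048:

# Hypothetical value for illustration; the script derives it from the EnCodec model's config.
encodec_vocab_size = 2048

pad_token_id = encodec_vocab_size                # 2048: first id past the last audio code
decoder_start_token_id = encodec_vocab_size + 1  # 2049: token the decoder starts generation from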
parler_tts/configuration_parler_tts.py (53 additions, 1 deletion)

@@ -47,6 +47,17 @@ class ParlerTTSDecoderConfig(PretrainedConfig):
Number of decoder layers.
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer block.
num_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
`num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by mean-pooling all the original heads within that group. For more details, check out [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`num_attention_heads`.
num_cross_attention_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to implement Grouped Query Attention in the cross-attention layers.
If it is not specified, will default to `num_key_value_heads`.
ffn_dim (`int`, *optional*, defaults to 4096):
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer block.
activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
Expand Down Expand Up @@ -74,6 +85,12 @@ class ParlerTTSDecoderConfig(PretrainedConfig):
The number of parallel codebooks forwarded to the model.
tie_word_embeddings(`bool`, *optional*, defaults to `False`):
Whether input and output word embeddings should be tied.
rope_embeddings (`bool`, *optional*, defaults to `False`):
Whether to use RoPE or absolute positional embeddings.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
cross_attention_implementation_strategy (`str`, *optional*):
If not specified, the cross-attention implementation will be the same as `_attn_implementation`. If `always_eager`, it will always be the eager implementation. If `always_sdpa`, it will always be the sdpa implementation.
"""

model_type = "parler_tts_decoder"
@@ -86,6 +103,8 @@ def __init__(
num_hidden_layers=24,
ffn_dim=4096,
num_attention_heads=16,
num_key_value_heads=None,
num_cross_attention_key_value_heads=None,
layerdrop=0.0,
use_cache=True,
activation_function="gelu",
@@ -100,6 +119,9 @@
bos_token_id=2049,
eos_token_id=2048,
tie_word_embeddings=False,
rope_embeddings=False,
rope_theta=10_000.0,
cross_attention_implementation_strategy=None,
**kwargs,
):
self.vocab_size = vocab_size
@@ -108,6 +130,12 @@ def __init__(
self.ffn_dim = ffn_dim
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
if num_cross_attention_key_value_heads is None:
num_cross_attention_key_value_heads = num_key_value_heads
self.num_cross_attention_key_value_heads = num_cross_attention_key_value_heads
self.dropout = dropout
self.attention_dropout = attention_dropout
self.activation_dropout = activation_dropout
@@ -117,6 +145,9 @@ def __init__(
self.use_cache = use_cache
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
self.num_codebooks = num_codebooks
self.rope_embeddings = rope_embeddings
self.rope_theta = rope_theta
self.cross_attention_implementation_strategy = cross_attention_implementation_strategy

super().__init__(
pad_token_id=pad_token_id,
@@ -140,6 +171,8 @@ class ParlerTTSConfig(PretrainedConfig):
vocab_size (`int`, *optional*, defaults to 1024):
Vocabulary size of the prompt token ids. Defines the number of different tokens that can be
represented by the `prompt_inputs_ids`.
prompt_cross_attention (`bool`, *optional*, defaults to `False`):
Whether to use cross-attention conditioning for the prompt (as well as the description).
kwargs (*optional*):
Dictionary of keyword arguments. Notably:

@@ -190,7 +223,7 @@ class ParlerTTSConfig(PretrainedConfig):
model_type = "parler_tts"
is_composition = True

def __init__(self, vocab_size=1024, **kwargs):
def __init__(self, vocab_size=1024, prompt_cross_attention=False, **kwargs):
super().__init__(**kwargs)
if "text_encoder" not in kwargs or "audio_encoder" not in kwargs or "decoder" not in kwargs:
raise ValueError("Config has to be initialized with text_encoder, audio_encoder and decoder config")
@@ -204,6 +237,7 @@ def __init__(self, vocab_size=1024, **kwargs):
decoder_config = kwargs.pop("decoder")

self.vocab_size = vocab_size
self.prompt_cross_attention = prompt_cross_attention
self.text_encoder = AutoConfig.for_model(text_encoder_model_type, **text_encoder_config)
self.audio_encoder = AutoConfig.for_model(audio_encoder_model_type, **audio_encoder_config)
self.decoder = ParlerTTSDecoderConfig(**decoder_config)
@@ -236,3 +270,21 @@ def from_sub_models_config(
# This is a property because you might want to change the codec model on the fly
def sampling_rate(self):
return self.audio_encoder.sampling_rate

# Copy from musicgen
@property
def _attn_implementation(self):
# This property is made private for now (as it cannot be changed and a PreTrainedModel.use_attn_implementation method needs to be implemented.)
if hasattr(self, "_attn_implementation_internal"):
if self._attn_implementation_internal is None:
# `config.attn_implementation` should never be None, for backward compatibility.
return "eager"
else:
return self._attn_implementation_internal
else:
return "eager"

@_attn_implementation.setter
def _attn_implementation(self, value):
self._attn_implementation_internal = value
self.decoder._attn_implementation = value
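
Taken together, the new decoder options added in this file can be exercised as below; a sketch based only on the __init__ signature above, with illustrative values, and assuming ParlerTTSDecoderConfig is importable from the parler_tts package:

from parler_tts import ParlerTTSDecoderConfig

decoder_config = ParlerTTSDecoderConfig(
    num_attention_heads=16,
    num_key_value_heads=4,                   # 4 KV heads shared by 16 query heads -> grouped-query attention
    num_cross_attention_key_value_heads=1,   # single KV head in cross-attention -> multi-query attention there
    rope_embeddings=True,                    # rotary positions instead of absolute positional embeddings
    rope_theta=10_000.0,
    cross_attention_implementation_strategy="always_sdpa",  # force SDPA for cross-attention only
)

If num_key_value_heads is omitted it falls back to num_attention_heads (plain multi-head attention), and num_cross_attention_key_value_heads in turn falls back to num_key_value_heads, as the __init__ above shows.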