Architecture improvements (huggingface#65)
* add RoPE

* don't include padding in rope

* possibly use cross-attn for prompt

* fix rope

* fix cross-attn

* fix self-attn

* fix dummy model

* clean-up rope

* first gqa implementation

* fix wer eval

* feat: add flash attention and sdpa

* chore: add README for flash attention

* chore: add benchmark script

* chore: add benchmark attention approach

* multi node and fix wer and fix compile

* Update modeling_parler_tts.py

* fix FA2, SDPA and add cross-attn MHA and attention type forcing

* better default for the number of cross-attention key/value heads + add training arguments for attn implementation

* fix audio padding when torch compile or pad_to_max_length=True

* correct multi node

* make rope faster

* fix encoder sdpa

* fix training with cross attention + with FA2

* use fp32 as default model dtype + fix generation when using FA2 with autocast

* remove redundant passes in generate + clean and fix attentions

* fix edge case in WER evaluation with long-form generation

* better multi-node mapping and saving / add eval dataloader num workers

* remove old benchmarks

* faster audio encoding + checkpointing + fix generation step

* better eval + add right padding + fix eval loss compute

* correct README

* correct config docstrings

* remove comment

* make style

---------

Co-authored-by: sanchit-gandhi <[email protected]>
Co-authored-by: sang-nguyen-ts <[email protected]>
Co-authored-by: [email protected] <Yoach Lacombe>
3 people authored Jul 31, 2024
1 parent 8b8c576 commit 11b209e
Showing 12 changed files with 1,325 additions and 267 deletions.
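Taken together with the commit message above (flash attention, SDPA, attention-type forcing) and the README diff below (passing `torch_dtype` to `from_pretrained`), loading the model after this change could look roughly like the sketch below. It assumes the standard `transformers` `attn_implementation` keyword is honored by this model; the device logic and the choice of `"sdpa"` are illustrative, not part of the commit.

```python
# Sketch only: assumes the transformers-style `attn_implementation` keyword
# ("eager", "sdpa", "flash_attention_2") is supported after this commit;
# "flash_attention_2" additionally needs the flash-attn package and a
# half-precision dtype.
import torch
from parler_tts import ParlerTTSForConditionalGeneration

device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device != "cpu" else torch.float32

model = ParlerTTSForConditionalGeneration.from_pretrained(
    "parler-tts/parler_tts_mini_v0.1",
    torch_dtype=torch_dtype,     # dtype now goes through from_pretrained (see README diff)
    attn_implementation="sdpa",  # force the attention backend
).to(device)
```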
README.md (3 changes: 2 additions & 1 deletion)
@@ -53,7 +53,8 @@ if torch.xpu.is_available():
device = "xpu"
torch_dtype = torch.float16 if device != "cpu" else torch.float32

-model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler_tts_mini_v0.1").to(device, dtype=torch_dtype)
+model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler_tts_mini_v0.1", torch_dtype=torch_dtype).to(device)

tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler_tts_mini_v0.1")

prompt = "Hey, how are you doing today?"
helpers/model_init_scripts/init_dummy_model.py (4 changes: 2 additions & 2 deletions)
@@ -60,8 +60,8 @@
model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate)
model.generation_config.do_sample = True # True
model.generation_config.guidance_scale = 1 # 3.0

model.config.pad_token_id = encodec_vocab_size
-model.config.decoder_start_token_id = encodec_vocab_size+1
+model.config.decoder_start_token_id = encodec_vocab_size + 1

model.save_pretrained(os.path.join(args.save_directory, "tiny-model"))
helpers/model_init_scripts/init_dummy_model_with_encodec.py (3 changes: 3 additions & 0 deletions)
@@ -58,4 +58,7 @@
model.generation_config.do_sample = True # True
model.generation_config.guidance_scale = 1 # 3.0

+model.config.pad_token_id = encodec_vocab_size
+model.config.decoder_start_token_id = encodec_vocab_size + 1

model.save_pretrained(os.path.join(args.save_directory, "tiny-model"))
helpers/model_init_scripts/init_model_600M.py (4 changes: 2 additions & 2 deletions)
@@ -60,8 +60,8 @@
model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate)
model.generation_config.do_sample = True # True
model.generation_config.guidance_scale = 1 # 3.0

model.config.pad_token_id = encodec_vocab_size
-model.config.decoder_start_token_id = encodec_vocab_size+1
+model.config.decoder_start_token_id = encodec_vocab_size + 1

model.save_pretrained(os.path.join(args.save_directory, "parler-tts-untrained-600M/"))
helpers/push_to_hub_scripts/push_dac_to_hub.py (1 change: 1 addition & 0 deletions)
@@ -2,6 +2,7 @@
from parler_tts import DACConfig, DACModel
from transformers import AutoConfig, AutoModel
from transformers import EncodecFeatureExtractor

AutoConfig.register("dac", DACConfig)
AutoModel.register(DACConfig, DACModel)

parler_tts/configuration_parler_tts.py (54 changes: 53 additions & 1 deletion)
@@ -47,6 +47,17 @@ class ParlerTTSDecoderConfig(PretrainedConfig):
Number of decoder layers.
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer block.
num_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by mean-pooling all the original heads within that group. For more details, check out [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If not specified, defaults to
`num_attention_heads`.
num_cross_attention_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to implement Grouped Query Attention in the cross-attention layers.
If not specified, defaults to `num_key_value_heads`.
ffn_dim (`int`, *optional*, defaults to 4096):
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer block.
activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
@@ -74,6 +85,12 @@ class ParlerTTSDecoderConfig(PretrainedConfig):
The number of parallel codebooks forwarded to the model.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether input and output word embeddings should be tied.
rope_embeddings (`bool`, *optional*, defaults to `False`):
Whether to use RoPE or absolute positional embeddings.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
cross_attention_implementation_strategy (`str`, *optional*):
If not specified, the cross-attention implementation will be the same as `_attn_implementation`. If `always_eager`, it will always be the eager implementation. If `always_sdpa`, it will always be the sdpa implementation.
"""

model_type = "parler_tts_decoder"
@@ -86,6 +103,8 @@ def __init__(
num_hidden_layers=24,
ffn_dim=4096,
num_attention_heads=16,
num_key_value_heads=None,
num_cross_attention_key_value_heads=None,
layerdrop=0.0,
use_cache=True,
activation_function="gelu",
Expand All @@ -100,6 +119,9 @@ def __init__(
bos_token_id=2049,
eos_token_id=2048,
tie_word_embeddings=False,
rope_embeddings=False,
rope_theta=10_000.0,
cross_attention_implementation_strategy=None,
**kwargs,
):
self.vocab_size = vocab_size
@@ -108,6 +130,12 @@ def __init__(
self.ffn_dim = ffn_dim
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
if num_cross_attention_key_value_heads is None:
num_cross_attention_key_value_heads = num_key_value_heads
self.num_cross_attention_key_value_heads = num_cross_attention_key_value_heads
self.dropout = dropout
self.attention_dropout = attention_dropout
self.activation_dropout = activation_dropout
@@ -117,6 +145,9 @@ def __init__(
self.use_cache = use_cache
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
self.num_codebooks = num_codebooks
self.rope_embeddings = rope_embeddings
self.rope_theta = rope_theta
self.cross_attention_implementation_strategy = cross_attention_implementation_strategy

super().__init__(
pad_token_id=pad_token_id,
@@ -140,6 +171,8 @@ class ParlerTTSConfig(PretrainedConfig):
vocab_size (`int`, *optional*, defaults to 1024):
Vocabulary size of the prompt token ids. Defines the number of different tokens that can be
represented by the `prompt_inputs_ids`.
prompt_cross_attention (`bool`, *optional*, defaults to `False`):
Whether to use cross-attention conditioning for the prompt (as well as the description).
kwargs (*optional*):
Dictionary of keyword arguments. Notably:
@@ -190,7 +223,7 @@ class ParlerTTSConfig(PretrainedConfig):
model_type = "parler_tts"
is_composition = True

def __init__(self, vocab_size=1024, **kwargs):
def __init__(self, vocab_size=1024, prompt_cross_attention=False, **kwargs):
super().__init__(**kwargs)
if "text_encoder" not in kwargs or "audio_encoder" not in kwargs or "decoder" not in kwargs:
raise ValueError("Config has to be initialized with text_encoder, audio_encoder and decoder config")
@@ -204,6 +237,7 @@ def __init__(self, vocab_size=1024, **kwargs):
decoder_config = kwargs.pop("decoder")

self.vocab_size = vocab_size
self.prompt_cross_attention = prompt_cross_attention
self.text_encoder = AutoConfig.for_model(text_encoder_model_type, **text_encoder_config)
self.audio_encoder = AutoConfig.for_model(audio_encoder_model_type, **audio_encoder_config)
self.decoder = ParlerTTSDecoderConfig(**decoder_config)
@@ -236,3 +270,21 @@ def from_sub_models_config(
# This is a property because you might want to change the codec model on the fly
def sampling_rate(self):
return self.audio_encoder.sampling_rate

# Copy from musicgen
@property
def _attn_implementation(self):
# This property is made private for now (as it cannot be changed and a PreTrainedModel.use_attn_implementation method needs to be implemented.)
if hasattr(self, "_attn_implementation_internal"):
if self._attn_implementation_internal is None:
# `config.attn_implementation` should never be None, for backward compatibility.
return "eager"
else:
return self._attn_implementation_internal
else:
return "eager"

@_attn_implementation.setter
def _attn_implementation(self, value):
self._attn_implementation_internal = value
self.decoder._attn_implementation = value
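A short sketch of the behavior the property and setter above implement: the getter falls back to `"eager"` when nothing has been set, and assigning the property also updates the nested decoder config. The checkpoint id is taken from the README and serves only as an example.

```python
# Sketch of the fallback-and-propagation behavior shown above.
from parler_tts.configuration_parler_tts import ParlerTTSConfig

config = ParlerTTSConfig.from_pretrained("parler-tts/parler_tts_mini_v0.1")

print(config._attn_implementation)          # "eager" when nothing was set explicitly

config._attn_implementation = "sdpa"        # the setter stores the value ...
print(config._attn_implementation)          # "sdpa"
print(config.decoder._attn_implementation)  # ... and propagates it to the decoder config
```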